diff --git a/README.md b/README.md
index 8fe83325c94cba1f9d51e6fbe2de518e3603541f..6d8f70c3997c29c9c651be8fc13a2ce108ae0f95 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,13 @@
 ---
-title: Medrag Multi Modal
-emoji: 🏆
-colorFrom: green
-colorTo: yellow
+title: MedRAG Multi-Modal
+emoji: 🩺
+colorFrom: blue
+colorTo: pink
 sdk: streamlit
-sdk_version: 1.40.0
+sdk_version: "1.39.0"
 app_file: app.py
 pinned: false
-short_description: Multi-modal assistant for medical professionals
 ---
+# MedRAG Multi-Modal
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+A multi-modal retrieval-augmented generation (RAG) assistant for the medical domain.
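+
+A minimal usage sketch, mirroring what `app.py` does (the BM25 index below is one of the published indices it uses, and `GOOGLE_API_KEY` is assumed to be set for the Gemini model):
+
+```python
+from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
+from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+retriever = BM25sRetriever.from_index(
+    index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
+)
+medqa_assistant = MedQAAssistant(
+    llm_client=LLMClient(model_name="gemini-1.5-flash-latest"),
+    retriever=retriever,
+)
+response = medqa_assistant.predict(query="What is the function of the hyoid bone?")
+print(response.response)
+```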
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb47e6591b88c69e3efd1994cf1fd99b9e1005a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,114 @@
+import streamlit as st
+
+from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
+from medrag_multi_modal.retrieval.text_retrieval import (
+    BM25sRetriever,
+    ContrieverRetriever,
+    MedCPTRetriever,
+    NVEmbed2Retriever,
+)
+
+# Define constants
+ALL_AVAILABLE_MODELS = [
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-pro-latest",
+    "gpt-4o",
+    "gpt-4o-mini",
+]
+
+# Sidebar for configuration settings
+st.sidebar.title("Configuration Settings")
+project_name = st.sidebar.text_input(
+    label="Project Name",
+    value="ml-colabs/medrag-multi-modal",
+    placeholder="wandb project name",
+    help="format: wandb_username/wandb_project_name",
+)
+chunk_dataset_id = st.sidebar.selectbox(
+    label="Chunk Dataset ID",
+    options=["ashwiniai/medrag-text-corpus-chunks"],
+)
+llm_model = st.sidebar.selectbox(
+    label="LLM Model",
+    options=ALL_AVAILABLE_MODELS,
+)
+top_k_chunks_for_query = st.sidebar.slider(
+    label="Top K Chunks for Query",
+    min_value=1,
+    max_value=20,
+    value=5,
+)
+top_k_chunks_for_options = st.sidebar.slider(
+    label="Top K Chunks for Options",
+    min_value=1,
+    max_value=20,
+    value=3,
+)
+rely_only_on_context = st.sidebar.checkbox(
+    label="Rely Only on Context",
+    value=False,
+)
+retriever_type = st.sidebar.selectbox(
+    label="Retriever Type",
+    options=[
+        "",
+        "BM25S",
+        "Contriever",
+        "MedCPT",
+        "NV-Embed-v2",
+    ],
+)
+
+if retriever_type != "":
+
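+    # Instantiate the LLM client and the selected retriever, then build the MedQA assistant and chat UI.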
+    llm_model = LLMClient(model_name=llm_model)
+
+    retriever = None
+
+    if retriever_type == "BM25S":
+        retriever = BM25sRetriever.from_index(
+            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
+        )
+    elif retriever_type == "Contriever":
+        retriever = ContrieverRetriever.from_index(
+            index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
+            chunk_dataset_id=chunk_dataset_id,
+        )
+    elif retriever_type == "MedCPT":
+        retriever = MedCPTRetriever.from_index(
+            index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
+            chunk_dataset_id=chunk_dataset_id,
+        )
+    elif retriever_type == "NV-Embed-v2":
+        retriever = NVEmbed2Retriever.from_index(
+            index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+            chunk_dataset_id=chunk_dataset_id,
+        )
+
+    medqa_assistant = MedQAAssistant(
+        llm_client=llm_model,
+        retriever=retriever,
+        top_k_chunks_for_query=top_k_chunks_for_query,
+        top_k_chunks_for_options=top_k_chunks_for_options,
+        rely_only_on_context=rely_only_on_context,
+    )
+
+    with st.chat_message("assistant"):
+        st.markdown(
+            """
+Hi! I am MedRAG, your medical assistant. You can ask me questions about medicine and the life sciences.
+I am currently a work in progress, so please bear with any mistakes or gaps in my knowledge.
+
+**Note:** I am not a medical professional, so please do not rely on my answers for medical decisions.
+Consult a medical professional for any medical advice.
+
+To learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
+            """,
+            unsafe_allow_html=True,
+        )
+    query = st.chat_input("Enter your question here")
+    if query:
+        with st.chat_message("user"):
+            st.markdown(query)
+        response = medqa_assistant.predict(query=query)
+        with st.chat_message("assistant"):
+            st.markdown(response.response)
diff --git a/medrag_multi_modal/__init__.py b/medrag_multi_modal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fa49d0b06f6966ca8571b186b6f5df8731fe885
Binary files /dev/null and b/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc differ
diff --git a/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc b/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59c00481373ca79b82b61d60014ecb8a249882d9
Binary files /dev/null and b/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc differ
diff --git a/medrag_multi_modal/__pycache__/cli.cpython-310.pyc b/medrag_multi_modal/__pycache__/cli.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06eb6b04321081a1f5a30045cf0b2659f112a6f7
Binary files /dev/null and b/medrag_multi_modal/__pycache__/cli.cpython-310.pyc differ
diff --git a/medrag_multi_modal/__pycache__/utils.cpython-310.pyc b/medrag_multi_modal/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8445b913c68d810a2bb8f33c3d53a57797a1ce8
Binary files /dev/null and b/medrag_multi_modal/__pycache__/utils.cpython-310.pyc differ
diff --git a/medrag_multi_modal/__pycache__/utils.cpython-39.pyc b/medrag_multi_modal/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40139b2fb48e976667ac4f1310bcaca54af718df
Binary files /dev/null and b/medrag_multi_modal/__pycache__/utils.cpython-39.pyc differ
diff --git a/medrag_multi_modal/assistant/__init__.py b/medrag_multi_modal/assistant/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bef57a938785df32db1cf835ee8b1a0fbc7d276
--- /dev/null
+++ b/medrag_multi_modal/assistant/__init__.py
@@ -0,0 +1,5 @@
+from .figure_annotation import FigureAnnotatorFromPageImage
+from .llm_client import ClientType, LLMClient
+from .medqa_assistant import MedQAAssistant
+
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
diff --git a/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35f51f75db22f247bbd5191203330d3da55ac60a
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed6e54de0a2027b22a8047ee6d2df9d23a89b8ea
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9e3e45e700ee14de0853ef095151dd6cb603d44
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..684968207bc589f187fc8a5cefa99875fd4e253e
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9976b496c227a3001cb51628d867abb4492dc394
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7256ed212132d24cc59b084221f8fadc5c11ea6b
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..793cde5aa95a66b382e6e05814385bbd7664c6ed
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc differ
diff --git a/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7c6c3be809cb097ff3154bae6daf11a2ea8d962
Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc differ
diff --git a/medrag_multi_modal/assistant/figure_annotation.py b/medrag_multi_modal/assistant/figure_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb3838004688355f117e922720b2c8558917e0e0
--- /dev/null
+++ b/medrag_multi_modal/assistant/figure_annotation.py
@@ -0,0 +1,147 @@
+import os
+from glob import glob
+from typing import Optional, Union
+
+import cv2
+import weave
+from PIL import Image
+
+from medrag_multi_modal.assistant.llm_client import LLMClient
+from medrag_multi_modal.assistant.schema import FigureAnnotations
+from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file
+
+
+class FigureAnnotatorFromPageImage(weave.Model):
+    """
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.
+
+    !!! example "Example Usage"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage, LLMClient
+        )
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        annotations = figure_annotator.predict(page_idx=34)
+        ```
+
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
+    """
+
+    figure_extraction_llm_client: LLMClient
+    structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
+
+    @weave.op()
+    def annotate_figures(
+        self, page_image: Image.Image
+    ) -> dict[str, Union[Image.Image, str]]:
+        annotation = self.figure_extraction_llm_client.predict(
+            system_prompt="""
+You are an expert in the domain of scientific textbooks, especially medical texts.
+You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
+You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc.
+Then you are to identify the figure IDs associated with each figure in the page image.
+Then, you are to extract only the exact figure descriptions from the page image.
+You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object.
+
+Here are some clues you need to follow:
+1. Figure IDs are unique identifiers for each figure in the page image.
+2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
+3. Figure IDs are in the form "Fig X.Y" where X and Y are integers, for example, Fig 1.1, Fig 1.2, Fig 1.3, etc.
+4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
+5. The text in the page image is written in English and is present in a two-column format.
+6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space.
+    You are to carefully identify all the figures in the page image.
+7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side
+    or one above the other.
+8. The figures may or may not have a distinct border against a white background.
+9. You are not to alter the figure descriptions present in the page image in any way; extract them exactly as they appear.
+""",
+            user_prompt=[page_image],
+        )
+        return {"page_image": page_image, "annotations": annotation}
+
+    @weave.op
+    def extract_structured_output(self, annotations: str) -> FigureAnnotations:
+        return self.structured_output_llm_client.predict(
+            system_prompt="You are supposed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
+            user_prompt=[annotations],
+            schema=FigureAnnotations,
+        )
+
+    @weave.op()
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]:
+        """
+        Predicts figure annotations for a specific page in a document.
+
+        This function reads the metadata from the 'metadata.jsonl' file in the image artifact
+        directory (resolved during initialization) and iterates through the metadata to find
+        the specified page index. If the page index matches, it reads the page image and the
+        associated figure images, then uses the `annotate_figures` method to extract figure
+        annotations from the page image. The extracted annotations are structured using the
+        `extract_structured_output` method and returned as a dictionary.
+
+        Args:
+            page_idx (int): The index of the page to annotate.
+
+        Returns:
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
+        """
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
+        annotations = {}
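+        # Locate the metadata entry for the requested page and, if it contains figures, annotate them.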
+        for item in metadata:
+            if item["page_idx"] == page_idx:
+                page_image_file = os.path.join(
+                    self._artifact_dir, f"page{item['page_idx']}.png"
+                )
+                figure_image_files = glob(
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
+                )
+                if len(figure_image_files) > 0:
+                    page_image = cv2.imread(page_image_file)
+                    page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
+                    page_image = Image.fromarray(page_image)
+                    figure_extracted_annotations = self.annotate_figures(
+                        page_image=page_image
+                    )
+                    figure_extracted_annotations = self.extract_structured_output(
+                        figure_extracted_annotations["annotations"]
+                    ).model_dump()
+                    annotations[item["page_idx"]] = figure_extracted_annotations[
+                        "annotations"
+                    ]
+                break
+        return annotations
diff --git a/medrag_multi_modal/assistant/llm_client.py b/medrag_multi_modal/assistant/llm_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee8ff5a637a5b342ac1689dc85fee903dff64b27
--- /dev/null
+++ b/medrag_multi_modal/assistant/llm_client.py
@@ -0,0 +1,245 @@
+import json
+import os
+from enum import Enum
+from typing import Any, Optional, Union
+
+import instructor
+import weave
+from PIL import Image
+
+from ..utils import base64_encode_image
+
+
+class ClientType(str, Enum):
+    GEMINI = "gemini"
+    MISTRAL = "mistral"
+    OPENAI = "openai"
+
+
+GOOGLE_MODELS = [
+    "gemini-1.0-pro-latest",
+    "gemini-1.0-pro",
+    "gemini-pro",
+    "gemini-1.0-pro-001",
+    "gemini-1.0-pro-vision-latest",
+    "gemini-pro-vision",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-exp-0801",
+    "gemini-1.5-pro-exp-0827",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-001-tuning",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-exp-0827",
+    "gemini-1.5-flash-002",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-001",
+    "gemini-1.5-flash-8b-latest",
+    "gemini-1.5-flash-8b-exp-0827",
+    "gemini-1.5-flash-8b-exp-0924",
+]
+
+MISTRAL_MODELS = [
+    "ministral-3b-latest",
+    "ministral-8b-latest",
+    "mistral-large-latest",
+    "mistral-small-latest",
+    "codestral-latest",
+    "pixtral-12b-2409",
+    "open-mistral-nemo",
+    "open-codestral-mamba",
+    "open-mistral-7b",
+    "open-mixtral-8x7b",
+    "open-mixtral-8x22b",
+]
+
+OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
+
+
+class LLMClient(weave.Model):
+    """
+    LLMClient is a class that interfaces with different large language model (LLM) providers
+    such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
+    these different APIs and provides a unified interface for making predictions.
+
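+    !!! example "Example Usage"
+        A minimal usage sketch; the API key for the chosen provider (e.g. `GOOGLE_API_KEY`
+        for Gemini models) is assumed to be set in the environment.
+
+        ```python
+        from medrag_multi_modal.assistant import LLMClient
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+        response = llm_client.predict(
+            system_prompt="You are an expert in anatomy.",
+            user_prompt=["Describe the function of the hyoid bone."],
+        )
+        ```
+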
+    Args:
+        model_name (str): The name of the model to be used for predictions.
+        client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
+            If not provided, it is inferred from the model_name.
+    """
+
+    model_name: str
+    client_type: Optional[ClientType]
+
+    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
+        if client_type is None:
+            if model_name in GOOGLE_MODELS:
+                client_type = ClientType.GEMINI
+            elif model_name in MISTRAL_MODELS:
+                client_type = ClientType.MISTRAL
+            elif model_name in OPENAI_MODELS:
+                client_type = ClientType.OPENAI
+            else:
+                raise ValueError(f"Invalid model name: {model_name}")
+        super().__init__(model_name=model_name, client_type=client_type)
+
+    @weave.op()
+    def execute_gemini_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        import google.generativeai as genai
+        from google.generativeai.types import HarmBlockThreshold, HarmCategory
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
+        model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
+        generation_config = (
+            None
+            if schema is None
+            else genai.GenerationConfig(
+                response_mime_type="application/json", response_schema=schema
+            )
+        )
+        response = model.generate_content(
+            user_prompt,
+            generation_config=generation_config,
+            # Relaxing these safety settings is necessary in order to answer questions about
+            # anatomy, sexually transmitted diseases, medical devices, medicines, etc.
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+            },
+        )
+        return response.text if schema is None else json.loads(response.text)
+
+    @weave.op()
+    def execute_mistral_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from mistralai import Mistral
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+        system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": base64_encode_image(prompt, "image/png"),
+                    }
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = [
+            {"role": "system", "content": system_messages},
+            {"role": "user", "content": user_messages},
+        ]
+
+        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
+
+        if schema is not None:
+            raise NotImplementedError(
+                "Mistral does not support structured output using a schema"
+            )
+        response = client.chat.complete(model=self.model_name, messages=messages)
+        return response.choices[0].message.content
+
+    @weave.op()
+    def execute_openai_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from openai import OpenAI
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        system_messages = [
+            {"role": "system", "content": prompt} for prompt in system_prompt
+        ]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": base64_encode_image(prompt, "image/png"),
+                        },
+                    },
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = system_messages + [{"role": "user", "content": user_messages}]
+
+        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+        if schema is None:
+            completion = client.chat.completions.create(
+                model=self.model_name, messages=messages
+            )
+            return completion.choices[0].message.content
+
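+        # Wrapping the parse call in weave.op() traces the structured-output request as its own op.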
+        completion = weave.op()(client.beta.chat.completions.parse)(
+            model=self.model_name, messages=messages, response_format=schema
+        )
+        return completion.choices[0].message.parsed
+
+    @weave.op()
+    def predict(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        """
+        Predicts the response from a language model based on the provided prompts and schema.
+
+        This function determines the client type and calls the appropriate SDK execution function
+        to get the response from the language model. It supports multiple client types including
+        GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
+        execution function with the provided user and system prompts, and an optional schema.
+
+        Args:
+            user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
+            system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
+            schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
+
+        Returns:
+            Union[str, Any]: The response from the language model, which could be a string or any other type
+            depending on the schema provided.
+
+        Raises:
+            ValueError: If the client type is invalid.
+        """
+        if self.client_type == ClientType.GEMINI:
+            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.MISTRAL:
+            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.OPENAI:
+            return self.execute_openai_sdk(user_prompt, system_prompt, schema)
+        else:
+            raise ValueError(f"Invalid client type: {self.client_type}")
diff --git a/medrag_multi_modal/assistant/medqa_assistant.py b/medrag_multi_modal/assistant/medqa_assistant.py
new file mode 100644
index 0000000000000000000000000000000000000000..95cc5e539958e9046aee6f69045f85bb86f1cf37
--- /dev/null
+++ b/medrag_multi_modal/assistant/medqa_assistant.py
@@ -0,0 +1,174 @@
+from typing import Optional
+
+import weave
+
+from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
+from medrag_multi_modal.assistant.llm_client import LLMClient
+from medrag_multi_modal.assistant.schema import (
+    MedQACitation,
+    MedQAMCQResponse,
+    MedQAResponse,
+)
+from medrag_multi_modal.retrieval.common import SimilarityMetric
+from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+
+class MedQAAssistant(weave.Model):
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever=MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator=FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
+        top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """
+
+    llm_client: LLMClient
+    retriever: weave.Model
+    figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
+    top_k_chunks_for_query: int = 2
+    top_k_chunks_for_options: int = 2
+    rely_only_on_context: bool = True
+    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
+
+    @weave.op()
+    def retrieve_chunks_for_query(self, query: str) -> list[dict]:
+        retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
+        if not isinstance(self.retriever, BM25sRetriever):
+            retriever_kwargs["metric"] = self.retrieval_similarity_metric
+        return self.retriever.predict(query, **retriever_kwargs)
+
+    @weave.op()
+    def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
+        retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
+        if not isinstance(self.retriever, BM25sRetriever):
+            retriever_kwargs["metric"] = self.retrieval_similarity_metric
+        retrieved_chunks = []
+        for option in options:
+            retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
+        return retrieved_chunks
+
+    @weave.op()
+    def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query and any provided options
+           using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
+           and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts, either choosing
+           from provided options or generating a free-form response.
+        6. Returns the generated response, which includes the answer and explanation if options were provided.
+
+        The function can operate in two modes:
+        - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
+        - Free response: When no options are provided, it generates a comprehensive response based on the context
+
+        Args:
+            query (str): The medical query to be answered.
+            options (Optional[list[str]]): The list of options to choose from.
+
+        Returns:
+            MedQAResponse: The generated response to the query, including source information.
+        """
+        retrieved_chunks = self.retrieve_chunks_for_query(query)
+        options = options or []
+        retrieved_chunks += self.retrieve_chunks_for_options(options)
+
+        retrieved_chunk_texts = []
+        page_indices = set()
+        for chunk in retrieved_chunks:
+            retrieved_chunk_texts.append(chunk["text"])
+            page_indices.add(int(chunk["page_idx"]))
+
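+        # Collect figure descriptions from the pages that contributed the retrieved chunks.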
+        figure_descriptions = []
+        if self.figure_annotator is not None:
+            for page_idx in page_indices:
+                figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                    page_idx
+                ]
+                figure_descriptions += [
+                    item["figure_description"] for item in figure_annotations
+                ]
+
+        system_prompt = """You are an expert in medical science. You are given a question
+and a list of excerpts from various medical documents.
+        """
+        query = f"""# Question
+{query}
+        """
+
+        if len(options) > 0:
+            system_prompt += """\nYou are also given a list of options to choose your answer from.
+You are supposed to choose the best possible option based on the context provided. You should also
+explain your answer to justify why you chose that option.
+"""
+            query += "## Options\n"
+            for option in options:
+                query += f"- {option}\n"
+        else:
+            system_prompt += "\nYou are supposed to answer the question based on the context provided."
+
+        if self.rely_only_on_context:
+            system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
+You are not allowed to use any external knowledge to answer the question.
+"""
+
+        response = self.llm_client.predict(
+            system_prompt=system_prompt,
+            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
+            schema=MedQAMCQResponse if len(options) > 0 else None,
+        )
+
+        # TODO: Add figure citations
+        # TODO: Add source document name from retrieved chunks as citations
+        citations = []
+        for page_idx in page_indices:
+            citations.append(
+                MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
+            )
+
+        return MedQAResponse(response=response, citations=citations)
diff --git a/medrag_multi_modal/assistant/schema.py b/medrag_multi_modal/assistant/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4ca6cfb280436a70aa04a50e7d48bd6dbc08f1
--- /dev/null
+++ b/medrag_multi_modal/assistant/schema.py
@@ -0,0 +1,27 @@
+from typing import Union
+
+from pydantic import BaseModel
+
+
+class FigureAnnotation(BaseModel):
+    figure_id: str
+    figure_description: str
+
+
+class FigureAnnotations(BaseModel):
+    annotations: list[FigureAnnotation]
+
+
+class MedQAMCQResponse(BaseModel):
+    answer: str
+    explanation: str
+
+
+class MedQACitation(BaseModel):
+    page_number: int
+    document_name: str
+
+
+class MedQAResponse(BaseModel):
+    response: Union[str, MedQAMCQResponse]
+    citations: list[MedQACitation]
diff --git a/medrag_multi_modal/cli.py b/medrag_multi_modal/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..419362447c14f4a3994fd7563818603bd271e24a
--- /dev/null
+++ b/medrag_multi_modal/cli.py
@@ -0,0 +1,68 @@
+import argparse
+import os
+import subprocess
+import sys
+
+
+def main():
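+    # Example invocations, assuming the package exposes this function as a console
+    # script named `medrag` (the entry point itself is not defined in this file):
+    #   medrag run --port 8501
+    #   medrag evaluate --model-name gemini-1.5-flash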
+    parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    # Run subcommand
+    run_parser = subparsers.add_parser("run", help="Run the Streamlit application")
+    run_parser.add_argument(
+        "--port", type=int, default=8501, help="Port to run Streamlit on"
+    )
+
+    # Evaluate subcommand
+    eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests")
+    eval_parser.add_argument(
+        "--test-file",
+        default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"),
+        help="Path to test file",
+    )
+    eval_parser.add_argument(
+        "--test-case",
+        type=str,
+        help="Only run tests which match the given substring expression",
+    )
+    eval_parser.add_argument(
+        "--model-name",
+        type=str,
+        default="gemini-1.5-flash",
+        help="Model name to use for evaluation",
+    )
+
+    args = parser.parse_args()
+
+    if args.command == "run":
+        subprocess.run(
+            [
+                sys.executable,
+                "-m",
+                "streamlit",
+                "run",
+                "app.py",
+                "--server.port",
+                str(args.port),
+            ]
+        )
+
+    elif args.command == "evaluate":
+        test_file = (
+            args.test_file + "::" + args.test_case if args.test_case else args.test_file
+        )
+        cmd = [
+            sys.executable,
+            "-m",
+            "pytest",
+            "-s",
+            test_file,
+            "-v",
+            f"--model-name={args.model_name}",
+        ]
+        subprocess.run(cmd)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/medrag_multi_modal/document_loader/__init__.py b/medrag_multi_modal/document_loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c48cb0307455aa952a0a060aba96d0964d1683a
--- /dev/null
+++ b/medrag_multi_modal/document_loader/__init__.py
@@ -0,0 +1,25 @@
+from .image_loader import (
+    FitzPILImageLoader,
+    MarkerImageLoader,
+    PDF2ImageLoader,
+    PDFPlumberImageLoader,
+    PyMuPDFImageLoader,
+)
+from .text_loader import (
+    MarkerTextLoader,
+    PDFPlumberTextLoader,
+    PyMuPDF4LLMTextLoader,
+    PyPDF2TextLoader,
+)
+
+__all__ = [
+    "PyMuPDF4LLMTextLoader",
+    "PyPDF2TextLoader",
+    "PDFPlumberTextLoader",
+    "MarkerTextLoader",
+    "PDF2ImageLoader",
+    "MarkerImageLoader",
+    "PDFPlumberImageLoader",
+    "PyMuPDFImageLoader",
+    "FitzPILImageLoader",
+]
diff --git a/medrag_multi_modal/document_loader/image_loader/__init__.py b/medrag_multi_modal/document_loader/image_loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0f43d7477e0d2ab5205f4861d14564c9e8597b
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/__init__.py
@@ -0,0 +1,13 @@
+from .fitzpil_img_loader import FitzPILImageLoader
+from .marker_img_loader import MarkerImageLoader
+from .pdf2image_img_loader import PDF2ImageLoader
+from .pdfplumber_img_loader import PDFPlumberImageLoader
+from .pymupdf_img_loader import PyMuPDFImageLoader
+
+__all__ = [
+    "PDF2ImageLoader",
+    "MarkerImageLoader",
+    "PDFPlumberImageLoader",
+    "PyMuPDFImageLoader",
+    "FitzPILImageLoader",
+]
diff --git a/medrag_multi_modal/document_loader/image_loader/base_img_loader.py b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdc99e99fb6102812b4106e3a5ae6eae73a1836e
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py
@@ -0,0 +1,180 @@
+import asyncio
+import os
+from abc import abstractmethod
+from glob import glob
+from typing import Dict, List, Optional
+
+import huggingface_hub
+import jsonlines
+import rich
+from datasets import (
+    Dataset,
+    Features,
+    Image,
+    Sequence,
+    Value,
+    concatenate_datasets,
+    load_dataset,
+)
+
+from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+    BaseTextLoader,
+)
+
+
+class BaseImageLoader(BaseTextLoader):
+    def __init__(self, url: str, document_name: str, document_file_path: str):
+        super().__init__(url, document_name, document_file_path)
+
+    @abstractmethod
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, str]:
+        """
+        Abstract method to process a single page of the PDF and extract the image data.
+
+        Override this method in the subclass to provide the actual implementation and
+        processing logic for each page of the PDF using various PDF processing libraries.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted images.
+            **kwargs: Additional keyword arguments that may be used by underlying libraries.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+        """
+        pass
+
+    def save_as_dataset(
+        self,
+        start_page: int,
+        end_page: int,
+        image_save_dir: str,
+        dataset_repo_id: Optional[str] = None,
+        overwrite_dataset: bool = False,
+    ):
+        features = Features(
+            {
+                "page_image": Image(decode=True),
+                "page_figure_images": Sequence(Image(decode=True)),
+                "document_name": Value(dtype="string"),
+                "page_idx": Value(dtype="int32"),
+            }
+        )
+
+        all_examples = []
+        for page_idx in range(start_page, end_page):
+            page_image_file_paths = glob(
+                os.path.join(image_save_dir, f"page{page_idx}*.png")
+            )
+            if len(page_image_file_paths) > 0:
+                page_image_path = page_image_file_paths[0]
+                figure_image_paths = [
+                    image_file_path
+                    for image_file_path in glob(
+                        os.path.join(image_save_dir, f"page{page_idx}*_fig*.png")
+                    )
+                ]
+
+                example = {
+                    "page_image": page_image_path,
+                    "page_figure_images": figure_image_paths,
+                    "document_name": self.document_name,
+                    "page_idx": page_idx,
+                }
+                all_examples.append(example)
+
+        dataset = Dataset.from_list(all_examples, features=features)
+
+        if dataset_repo_id:
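+            # If the dataset repo already exists and overwrite_dataset is False, append the
+            # previously published "corpus" split to the new examples before pushing.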
+            if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
+                if not overwrite_dataset:
+                    dataset = concatenate_datasets(
+                        [dataset, load_dataset(dataset_repo_id)["corpus"]]
+                    )
+
+            dataset.push_to_hub(dataset_repo_id, split="corpus")
+
+        return dataset
+
+    def cleanup_image_dir(self, image_save_dir: str = "./images"):
+        for file in os.listdir(image_save_dir):
+            file_path = os.path.join(image_save_dir, file)
+            if os.path.isfile(file_path):
+                os.remove(file_path)
+
+    async def load_data(
+        self,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        dataset_repo_id: Optional[str] = None,
+        overwrite_dataset: bool = False,
+        image_save_dir: str = "./images",
+        exclude_file_extensions: list[str] = [],
+        **kwargs,
+    ) -> Dataset:
+        """
+        Asynchronously loads images from a PDF file specified by a URL or local file path.
+        The overridden `extract_page_data` method processes each page, and the resulting
+        images are optionally published to a HuggingFace dataset.
+
+        This function downloads a PDF from a given URL if it does not already exist locally,
+        reads the specified range of pages, scans each page's content to extract images, and
+        returns a HuggingFace dataset containing the images and their metadata.
+
+        It uses `PyPDF2` to calculate the number of pages in the PDF, while the overridden
+        `extract_page_data` method provides the actual implementation to process each page,
+        extract the image content from the PDF, and convert it to PNG format.
+        Pages are processed concurrently using `asyncio` for efficiency.
+
+        If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
+
+        Args:
+            start_page (Optional[int]): The starting page index (0-based) to process.
+            end_page (Optional[int]): The ending page index (0-based) to process.
+            dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
+            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
+            image_save_dir (str): The directory to save the extracted images.
+            exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
+            **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
+
+        Returns:
+            Dataset: A HuggingFace dataset containing the processed pages.
+
+        Raises:
+            ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
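+
+        !!! example "Example Usage"
+            A minimal sketch using one of the concrete subclasses; the dataset repo id below is
+            a placeholder for a repository you can write to.
+
+            ```python
+            import asyncio
+
+            from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
+
+            URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+            loader = PDF2ImageLoader(
+                url=URL,
+                document_name="Gray's Anatomy",
+                document_file_path="grays_anatomy.pdf",
+            )
+            dataset = asyncio.run(
+                loader.load_data(
+                    start_page=32,
+                    end_page=37,
+                    dataset_repo_id="your-username/grays-anatomy-images",
+                )
+            )
+            ```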
+        """
+        os.makedirs(image_save_dir, exist_ok=True)
+        start_page, end_page = self.get_page_indices(start_page, end_page)
+        pages = []
+        processed_pages_counter: int = 1
+        total_pages = end_page - start_page
+
+        async def process_page(page_idx):
+            nonlocal processed_pages_counter
+            page_data = await self.extract_page_data(page_idx, image_save_dir, **kwargs)
+            pages.append(page_data)
+            rich.print(
+                f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
+            )
+            processed_pages_counter += 1
+
+        tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)]
+        for task in asyncio.as_completed(tasks):
+            await task
+
+        with jsonlines.open(
+            os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
+        ) as writer:
+            writer.write_all(pages)
+
+        for file in os.listdir(image_save_dir):
+            if file.endswith(tuple(exclude_file_extensions)):
+                os.remove(os.path.join(image_save_dir, file))
+
+        dataset = self.save_as_dataset(
+            start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset
+        )
+
+        return dataset
diff --git a/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..836c89e78b014df842328c2eaac8b6f20216ea39
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py
@@ -0,0 +1,127 @@
+import io
+import os
+from typing import Any, Dict
+
+import fitz
+from pdf2image.pdf2image import convert_from_path
+from PIL import Image, ImageOps, UnidentifiedImageError
+
+from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+    BaseImageLoader,
+)
+
+
+class FitzPILImageLoader(BaseImageLoader):
+    """
+    `FitzPILImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+    loading of pages from a PDF file as images using the fitz and PIL libraries.
+
+    This class provides functionality to extract images from a PDF file using the fitz and PIL libraries,
+    and optionally publish these images to a HuggingFace dataset.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = FitzPILImageLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+        ```
+
+    Args:
+        url (str): The URL of the PDF document.
+        document_name (str): The name of the document.
+        document_file_path (str): The path to the PDF file.
+    """
+
+    def __init__(self, url: str, document_name: str, document_file_path: str):
+        super().__init__(url, document_name, document_file_path)
+
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Extracts a single page from the PDF as an image using fitz and PIL libraries.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted image.
+            **kwargs: Additional keyword arguments that may be used by fitz and PIL.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "image_file_paths": (list) the local file paths where the images are stored.
+        """
+        image_file_paths = []
+
+        pdf_document = fitz.open(self.document_file_path)
+        page = pdf_document.load_page(page_idx)
+
+        images = page.get_images(full=True)
+        for img_idx, image in enumerate(images):
+            xref = image[0]
+            base_image = pdf_document.extract_image(xref)
+            image_bytes = base_image["image"]
+            image_ext = base_image["ext"]
+
+            try:
+                img = Image.open(io.BytesIO(image_bytes))
+
+                if img.mode == "L":
+                    # greyscale images appear inverted; this needs testing on other PDFs
+                    img = ImageOps.invert(img)
+
+                if img.mode == "CMYK":
+                    img = img.convert("RGB")
+
+                if image_ext not in ["png", "jpg", "jpeg"]:
+                    image_ext = "png"
+                    image_file_name = f"page{page_idx}_fig{img_idx}.png"
+                    image_file_path = os.path.join(image_save_dir, image_file_name)
+
+                    img.save(image_file_path, format="PNG")
+                else:
+                    image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
+                    image_file_path = os.path.join(image_save_dir, image_file_name)
+
+                    with open(image_file_path, "wb") as image_file:
+                        image_file.write(image_bytes)
+
+                image_file_paths.append(image_file_path)
+
+            except (UnidentifiedImageError, OSError) as e:
+                print(
+                    f"Skipping image at page {page_idx}, fig {img_idx} due to an error: {e}"
+                )
+                continue
+
+        pdf_document.close()
+
+        page_image = convert_from_path(
+            self.document_file_path,
+            first_page=page_idx + 1,
+            last_page=page_idx + 1,
+            **kwargs,
+        )[0]
+        page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+        return {
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+            "image_file_paths": image_file_paths,
+        }
diff --git a/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e66cf0af74985c58200099c50ab103d6c8af6250
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py
@@ -0,0 +1,131 @@
+import os
+from typing import Any, Coroutine, Dict, List
+
+from marker.convert import convert_single_pdf
+from marker.models import load_all_models
+from pdf2image.pdf2image import convert_from_path
+
+from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+    BaseImageLoader,
+)
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+class MarkerImageLoader(BaseImageLoader):
+    """
+    `MarkerImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+    loading of pages from a PDF file as images using the marker library.
+
+    This class provides functionality to extract images from a PDF file using the marker library,
+    and optionally publish these images to a HuggingFace dataset.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = MarkerImageLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+        ```
+
+    Args:
+        url (str): The URL of the PDF document.
+        document_name (str): The name of the document.
+        document_file_path (str): The path to the PDF file.
+        save_page_image (bool): Whether to additionally save the image of the entire page.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        document_name: str,
+        document_file_path: str,
+        save_page_image: bool = False,
+    ):
+        super().__init__(url, document_name, document_file_path)
+        self.save_page_image = save_page_image
+        self.model_lst = load_all_models()
+
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Extracts a single page from the PDF as an image using marker library.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted image.
+            **kwargs: Additional keyword arguments that may be used by marker.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "image_file_path": (str) the local file path where the image is stored.
+        """
+        _, images, _ = convert_single_pdf(
+            self.document_file_path,
+            self.model_lst,
+            max_pages=1,
+            batch_multiplier=1,
+            start_page=page_idx,
+            ocr_all_pages=True,
+            **kwargs,
+        )
+
+        image_file_paths = []
+        for img_idx, (_, image) in enumerate(images.items()):
+            image_file_name = f"page{page_idx}_fig{img_idx}.png"
+            image_file_path = os.path.join(image_save_dir, image_file_name)
+            image.save(image_file_path, "png")
+            image_file_paths.append(image_file_path)
+
+        page_image = convert_from_path(
+            self.document_file_path,
+            first_page=page_idx + 1,
+            last_page=page_idx + 1,
+            **kwargs,
+        )[0]
+        page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+        return {
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+            "image_file_paths": os.path.join(image_save_dir, "*.png"),
+        }
+
+    def load_data(
+        self,
+        start_page: int | None = None,
+        end_page: int | None = None,
+        wandb_artifact_name: str | None = None,
+        image_save_dir: str = "./images",
+        exclude_file_extensions: list[str] = [],
+        cleanup: bool = False,
+        **kwargs,
+    ) -> Coroutine[Any, Any, List[Dict[str, str]]]:
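+        # Page numbers passed to this loader are 1-based; convert them to the 0-based indices
+        # expected by the base image loader.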
+        start_page = start_page - 1 if start_page is not None else None
+        end_page = end_page - 1 if end_page is not None else None
+        return super().load_data(
+            start_page,
+            end_page,
+            wandb_artifact_name,
+            image_save_dir,
+            exclude_file_extensions,
+            cleanup,
+            **kwargs,
+        )
diff --git a/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd5abaad407c8e7781275fc30a19a771f902cf74
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py
@@ -0,0 +1,83 @@
+import os
+from typing import Any, Dict
+
+from pdf2image.pdf2image import convert_from_path
+
+from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+    BaseImageLoader,
+)
+
+
+class PDF2ImageLoader(BaseImageLoader):
+    """
+    `PDF2ImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+    loading of pages from a PDF file as images using the pdf2image library.
+
+    This class provides functionality to convert specific pages of a PDF document into images
+    and optionally publish these images to a WandB artifact.
+    Each page is rendered as a single full-page snapshot image.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PDF2ImageLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+        ```
+
+    Args:
+        url (str): The URL of the PDF document.
+        document_name (str): The name of the document.
+        document_file_path (str): The path to the PDF file.
+    """
+
+    def __init__(self, url: str, document_name: str, document_file_path: str):
+        super().__init__(url, document_name, document_file_path)
+
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Extracts a single page from the PDF as an image using the pdf2image library.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted image.
+            **kwargs: Additional keyword arguments that may be used by pdf2image.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "image_file_path": (str) the local file path where the image is stored.
+        """
+        image = convert_from_path(
+            self.document_file_path,
+            first_page=page_idx + 1,
+            last_page=page_idx + 1,
+            **kwargs,
+        )[0]
+
+        image_file_name = f"page{page_idx}.png"
+        image_file_path = os.path.join(image_save_dir, image_file_name)
+        image.save(image_file_path)
+
+        return {
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+            "image_file_path": image_file_path,
+        }
diff --git a/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2635071c6b58751c8791c07e0c96877e8210f3c1
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py
@@ -0,0 +1,101 @@
+import os
+from typing import Any, Dict
+
+import pdfplumber
+from pdf2image.pdf2image import convert_from_path
+
+from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+    BaseImageLoader,
+)
+
+
+class PDFPlumberImageLoader(BaseImageLoader):
+    """
+    `PDFPlumberImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+    loading of pages from a PDF file as images using the pdfplumber library.
+
+    This class provides functionality to extract images from a PDF file using the pdfplumber library,
+    and optionally publish these images to a WandB artifact.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PDFPlumberImageLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+        ```
+
+    Args:
+        url (str): The URL of the PDF document.
+        document_name (str): The name of the document.
+        document_file_path (str): The path to the PDF file.
+    """
+
+    def __init__(self, url: str, document_name: str, document_file_path: str):
+        super().__init__(url, document_name, document_file_path)
+
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Extracts the images from a single page of the PDF using the pdfplumber library, along with a full-page snapshot.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted image.
+            **kwargs: Additional keyword arguments passed to `pdf2image.convert_from_path` when rendering the full-page snapshot.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "image_file_path": (str) the local file path where the image is stored.
+        """
+        with pdfplumber.open(self.document_file_path) as pdf:
+            page = pdf.pages[page_idx]
+            images = page.images
+
+            image_file_paths = []
+            for img_idx, image in enumerate(images):
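+                # Crop the page to the image's bounding box and rasterize it at 300 DPI.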
+                extracted_image = page.crop(
+                    (
+                        image["x0"],
+                        image["top"],
+                        image["x1"],
+                        image["bottom"],
+                    )
+                ).to_image(resolution=300)
+
+                image_file_name = f"page{page_idx}_fig{img_idx}.png"
+                image_file_path = os.path.join(image_save_dir, image_file_name)
+
+                extracted_image.save(image_file_path, "png")
+                image_file_paths.append(image_file_path)
+
+        page_image = convert_from_path(
+            self.document_file_path,
+            first_page=page_idx + 1,
+            last_page=page_idx + 1,
+            **kwargs,
+        )[0]
+        page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+        return {
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+            "image_file_paths": image_file_paths,
+        }
diff --git a/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..336b8afc01fa421f8b2b7ae4d6beedd3cdf54ace
--- /dev/null
+++ b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py
@@ -0,0 +1,124 @@
+import io
+import os
+from typing import Any, Dict
+
+import fitz
+from pdf2image.pdf2image import convert_from_path
+from PIL import Image
+
+from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
+    BaseImageLoader,
+)
+
+
+class PyMuPDFImageLoader(BaseImageLoader):
+    """
+    `PyMuPDFImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and
+    loading of pages from a PDF file as images using the pymupdf library.
+
+    This class provides functionality to extract images from a PDF file using the pymupdf library,
+    and optionally publish these images to a WandB artifact.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PyMuPDFImageLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
+        ```
+
+    Args:
+        url (str): The URL of the PDF document.
+        document_name (str): The name of the document.
+        document_file_path (str): The path to the PDF file.
+    """
+
+    def __init__(self, url: str, document_name: str, document_file_path: str):
+        super().__init__(url, document_name, document_file_path)
+
+    async def extract_page_data(
+        self, page_idx: int, image_save_dir: str, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Extracts the images from a single page of the PDF using the pymupdf library, along with a full-page snapshot.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            image_save_dir (str): The directory to save the extracted image.
+            **kwargs: Additional keyword arguments passed to `pdf2image.convert_from_path` when rendering the full-page snapshot.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "image_file_paths": (list) the local file paths where the images are stored.
+        """
+        image_file_paths = []
+
+        pdf_document = fitz.open(self.document_file_path)
+        page = pdf_document[page_idx]
+
+        images = page.get_images(full=True)
+        for img_idx, image in enumerate(images):
+            xref = image[0]
+            base_image = pdf_document.extract_image(xref)
+            image_bytes = base_image["image"]
+            image_ext = base_image["ext"]
+
+            if image_ext == "jb2":
+                image_ext = "png"
+            elif image_ext == "jpx":
+                image_ext = "jpg"
+
+            image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}"
+            image_file_path = os.path.join(image_save_dir, image_file_name)
+
+            # For JBIG2 and JPEG2000, we need to convert the image
+            if base_image["ext"] in ["jb2", "jpx"]:
+                try:
+                    pix = fitz.Pixmap(image_bytes)
+                    pix.save(image_file_path)
+                except Exception as err_fitz:
+                    print(f"Error processing image with fitz: {err_fitz}")
+                    # Fallback to using PIL for image conversion
+                    try:
+                        img = Image.open(io.BytesIO(image_bytes))
+                        img.save(image_file_path)
+                    except Exception as err_pil:
+                        print(f"Failed to process image with PIL: {err_pil}")
+                        continue  # Skip this image if both methods fail
+            else:
+                with open(image_file_path, "wb") as image_file:
+                    image_file.write(image_bytes)
+
+            image_file_paths.append(image_file_path)
+
+        pdf_document.close()
+
+        page_image = convert_from_path(
+            self.document_file_path,
+            first_page=page_idx + 1,
+            last_page=page_idx + 1,
+            **kwargs,
+        )[0]
+        page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
+        return {
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+            "image_file_paths": image_file_paths,
+        }
diff --git a/medrag_multi_modal/document_loader/text_loader/__init__.py b/medrag_multi_modal/document_loader/text_loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d18dbb893992b9196adc9277b3b758d64a2bce
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/__init__.py
@@ -0,0 +1,11 @@
+from .marker_text_loader import MarkerTextLoader
+from .pdfplumber_text_loader import PDFPlumberTextLoader
+from .pymupdf4llm_text_loader import PyMuPDF4LLMTextLoader
+from .pypdf2_text_loader import PyPDF2TextLoader
+
+__all__ = [
+    "PyMuPDF4LLMTextLoader",
+    "PyPDF2TextLoader",
+    "PDFPlumberTextLoader",
+    "MarkerTextLoader",
+]
diff --git a/medrag_multi_modal/document_loader/text_loader/base_text_loader.py b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bc4dc455223eaaf75b61de500bf740d5fe9446
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py
@@ -0,0 +1,185 @@
+import asyncio
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+
+import huggingface_hub
+import PyPDF2
+from datasets import Dataset, concatenate_datasets, load_dataset
+from firerequests import FireRequests
+from rich.progress import Progress
+
+
+class BaseTextLoader(ABC):
+    """
+    An abstract base class for loading text from a PDF file, processing it into markdown, and optionally publishing it to a HuggingFace dataset.
+
+    This class handles the downloading of a PDF file from a given URL if it does not already exist locally.
+    Subclasses should implement the specific PDF reading, text extraction, and markdown conversion methods.
+
+    The processed pages are collected into a HuggingFace `Dataset`, which can optionally be published to the HuggingFace Hub.
+
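+    !!! example "Example Usage"
+        A minimal, illustrative sketch using one of the concrete subclasses; the
+        `metadata` values below are placeholders.
+
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader import PyPDF2TextLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PyPDF2TextLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+            metadata={"source": "illustrative-example"},
+        )
+        dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+        ```
+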
+    Args:
+        url (str): The URL of the PDF file to download if not present locally.
+        document_name (str): The name of the document for metadata purposes.
+        document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+        metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        document_name: str,
+        document_file_path: str,
+        metadata: Optional[dict[str, Any]] = None,
+    ):
+        self.url = url
+        self.document_name = document_name
+        self.document_file_path = document_file_path
+        self.metadata = metadata or {}
+        if not os.path.exists(self.document_file_path):
+            FireRequests().download(url, filenames=self.document_file_path)
+        with open(self.document_file_path, "rb") as file:
+            pdf_reader = PyPDF2.PdfReader(file)
+            self.page_count = len(pdf_reader.pages)
+
+    def get_page_indices(
+        self, start_page: Optional[int] = None, end_page: Optional[int] = None
+    ) -> tuple[int, int]:
+        """
+        Get the start and end page indices for processing.
+
+        Args:
+            start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
+            end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+
+        Returns:
+            tuple[int, int]: A tuple containing the start and end page indices.
+        """
+
+        if start_page:
+            if start_page > self.page_count:
+                raise ValueError(
+                    f"Start page {start_page} is greater than the total page count {self.page_count}"
+                )
+        else:
+            start_page = 0
+        if end_page:
+            if end_page > self.page_count:
+                raise ValueError(
+                    f"End page {end_page} is greater than the total page count {self.page_count}"
+                )
+        else:
+            end_page = self.page_count - 1
+        return start_page, end_page
+
+    @abstractmethod
+    async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+        """
+        Abstract method to process a single page of the PDF and extract the text data.
+
+        Overwrite this method in the subclass to provide the actual implementation and
+        processing logic for each page of the PDF using various PDF processing libraries.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments that may be used by underlying libraries.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+        """
+        pass
+
+    async def load_data(
+        self,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        exclude_pages: Optional[list[int]] = None,
+        dataset_repo_id: Optional[str] = None,
+        overwrite_dataset: bool = False,
+        **kwargs,
+    ) -> Dataset:
+        """
+        Asynchronously loads text from a PDF file specified by a URL or local file path.
+        The overridden `extract_page_data` method processes each page's text into markdown format,
+        and the result can optionally be published to a HuggingFace dataset.
+
+        This method downloads the PDF from the given URL if it does not already exist locally,
+        reads the specified range of pages, converts each page's content to markdown, and
+        returns a HuggingFace `Dataset` containing the text and metadata.
+
+        It uses `PyPDF2` to determine the number of pages in the PDF, while the
+        overridden `extract_page_data` method provides the actual implementation that
+        extracts the text from each page and converts it to markdown.
+        Pages are processed concurrently using `asyncio` for efficiency.
+
+        If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
+
+        Args:
+            start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
+            end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+            exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
+            dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
+            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
+            **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
+
+        Returns:
+            Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
+            Each entry in the dataset will have the following keys and values:
+
+            - "text": (str) the processed page data in markdown format.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "loader_name": (str) the name of the loader class used to process the page.
+
+        Raises:
+            ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
+        """
+        start_page, end_page = self.get_page_indices(start_page, end_page)
+        pages = []
+        processed_pages_counter: int = 1
+        total_pages = end_page - start_page + 1
+        exclude_pages = exclude_pages or []
+
+        async def process_page(page_idx):
+            nonlocal processed_pages_counter
+            page_data = await self.extract_page_data(page_idx, **kwargs)
+            page_data["loader_name"] = self.__class__.__name__
+            for key, value in self.metadata.items():
+                if key not in page_data:
+                    page_data[key] = value
+            pages.append(page_data)
+            progress.update(
+                task_id,
+                advance=1,
+                description=f"Loading page {page_idx} using {self.__class__.__name__}",
+            )
+            processed_pages_counter += 1
+
+        progress = Progress()
+        with progress:
+            task_id = progress.add_task("Starting...", total=total_pages)
+            tasks = [
+                process_page(page_idx)
+                for page_idx in range(start_page, end_page + 1)
+                if page_idx not in exclude_pages
+            ]
+            for task in asyncio.as_completed(tasks):
+                await task
+
+        pages.sort(key=lambda x: x["page_idx"])
+
+        dataset = Dataset.from_list(pages)
+        if dataset_repo_id:
+            if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
+                print("Dataset already exists")
+                if not overwrite_dataset:
+                    print("Not overwriting dataset")
+                    dataset = concatenate_datasets(
+                        [dataset, load_dataset(dataset_repo_id, split="corpus")]
+                    )
+            dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False)
+
+        return dataset
diff --git a/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a19e3343bcd2d44f1481def4c7ad031838b815
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py
@@ -0,0 +1,89 @@
+import os
+from typing import Dict
+
+from marker.convert import convert_single_pdf
+from marker.models import load_all_models
+
+from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+    BaseTextLoader,
+)
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+
+class MarkerTextLoader(BaseTextLoader):
+    """
+    A concrete implementation of the BaseTextLoader for loading text from a PDF file
+    using `marker-pdf`, processing it into a structured text format, and optionally publishing
+    it to a HuggingFace dataset.
+
+    This class extends the BaseTextLoader and implements the abstract methods to
+    load and process pages from a PDF file using marker-pdf, which is a pipeline of deep learning models.
+
+    This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+    It uses marker-pdf to read the PDF and extract structured text from each page. The processed pages
+    are returned as a HuggingFace `Dataset`, which can optionally be published to the HuggingFace Hub.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader import MarkerTextLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = MarkerTextLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+        ```
+
+    Args:
+        url (str): The URL of the PDF file to download if not present locally.
+        document_name (str): The name of the document for metadata purposes.
+        document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+    """
+
+    async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+        """
+        Process a single page of the PDF and extract its structured text using marker-pdf.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments to be passed to `marker.convert.convert_single_pdf`.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "text": (str) the extracted structured text from the page.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+        """
+        model_lst = load_all_models()
+
+        text, _, _ = convert_single_pdf(
+            self.document_file_path,
+            model_lst,
+            max_pages=1,
+            batch_multiplier=1,
+            start_page=page_idx,
+            ocr_all_pages=True,
+            **kwargs,
+        )
+
+        return {
+            "text": text,
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+        }
diff --git a/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..337aed66516170baded9a6faa997c7a2e2503319
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py
@@ -0,0 +1,76 @@
+from typing import Dict
+
+import pdfplumber
+
+from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+    BaseTextLoader,
+)
+
+
+class PDFPlumberTextLoader(BaseTextLoader):
+    """
+    A concrete implementation of the BaseTextLoader for loading text from a PDF file
+    using `pdfplumber`, processing it into a simple text format, and optionally publishing
+    it to a HuggingFace dataset.
+
+    This class extends the BaseTextLoader and implements the abstract methods to
+    load and process pages from a PDF file.
+
+    This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+    It uses pdfplumber to read the PDF and extract text from each page. The processed pages are returned
+    as a HuggingFace `Dataset`, which can optionally be published to the HuggingFace Hub.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader import PDFPlumberTextLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PDFPlumberTextLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+        ```
+
+    Args:
+        url (str): The URL of the PDF file to download if not present locally.
+        document_name (str): The name of the document for metadata purposes.
+        document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+    """
+
+    async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+        """
+        Process a single page of the PDF and extract its text using pdfplumber.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments to be passed to `pdfplumber.Page.extract_text`.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "text": (str) the extracted text from the page.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+        """
+        with pdfplumber.open(self.document_file_path) as pdf:
+            page = pdf.pages[page_idx]
+            text = page.extract_text(**kwargs)
+
+        return {
+            "text": text,
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+        }
diff --git a/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..05493656683fa7378dff415e3861663667f1340a
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py
@@ -0,0 +1,73 @@
+from typing import Dict
+
+import pymupdf4llm
+
+from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+    BaseTextLoader,
+)
+
+
+class PyMuPDF4LLMTextLoader(BaseTextLoader):
+    """
+    A concrete implementation of the BaseTextLoader for loading text from a PDF file,
+    processing it into markdown using `pymupdf4llm`, and optionally publishing it to a HuggingFace dataset.
+
+    This class extends the BaseTextLoader and implements the abstract methods to load and process pages from a PDF file.
+
+    This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+    It uses PyPDF2 to count the pages and pymupdf4llm to convert pages to markdown. The processed pages are
+    returned as a HuggingFace `Dataset`, which can optionally be published to the HuggingFace Hub.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PyMuPDF4LLMTextLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+        ```
+
+    Args:
+        url (str): The URL of the PDF file to download if not present locally.
+        document_name (str): The name of the document for metadata purposes.
+        document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+    """
+
+    async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+        """
+        Process a single page of the PDF and convert it to markdown using `pymupdf4llm`.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments to be passed to `pymupdf4llm.to_markdown`.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "text": (str) the processed page data in markdown format.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+        """
+        text = pymupdf4llm.to_markdown(
+            doc=self.document_file_path, pages=[page_idx], show_progress=False, **kwargs
+        )
+        return {
+            "text": text,
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+        }
diff --git a/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..df6cc011e1a93d9c5f8dd1d25d88f41dff928623
--- /dev/null
+++ b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py
@@ -0,0 +1,77 @@
+from typing import Dict
+
+import PyPDF2
+
+from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
+    BaseTextLoader,
+)
+
+
+class PyPDF2TextLoader(BaseTextLoader):
+    """
+    A concrete implementation of the BaseTextLoader for loading text from a PDF file
+    using `PyPDF2`, processing it into a simple text format, and optionally publishing
+    it to a HuggingFace dataset.
+
+    This class extends the BaseTextLoader and implements the abstract methods to
+    load and process pages from a PDF file.
+
+    This class will handle the downloading of a PDF file from a given URL if it does not already exist locally.
+    It uses PyPDF2 to read the PDF and extract text from each page. The processed pages are returned
+    as a HuggingFace `Dataset`, which can optionally be published to the HuggingFace Hub.
+
+    !!! example "Example Usage"
+        ```python
+        import asyncio
+
+        from medrag_multi_modal.document_loader import PyPDF2TextLoader
+
+        URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+        loader = PyPDF2TextLoader(
+            url=URL,
+            document_name="Gray's Anatomy",
+            document_file_path="grays_anatomy.pdf",
+        )
+        dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
+        ```
+
+    Args:
+        url (str): The URL of the PDF file to download if not present locally.
+        document_name (str): The name of the document for metadata purposes.
+        document_file_path (str): The local file path where the PDF is stored or will be downloaded.
+    """
+
+    async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]:
+        """
+        Process a single page of the PDF and extract its text using PyPDF2.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments to be passed to `PyPDF2.PdfReader.pages[0].extract_text`.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+            The dictionary will have the following keys and values:
+
+            - "text": (str) the extracted text from the page.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+        """
+        with open(self.document_file_path, "rb") as file:
+            pdf_reader = PyPDF2.PdfReader(file)
+            page = pdf_reader.pages[page_idx]
+            text = page.extract_text(**kwargs)
+
+        return {
+            "text": text,
+            "page_idx": page_idx,
+            "document_name": self.document_name,
+            "file_path": self.document_file_path,
+            "file_url": self.url,
+        }
diff --git a/medrag_multi_modal/metrics/__init__.py b/medrag_multi_modal/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b7bf585104b0331b74407570a1d6744c836320
--- /dev/null
+++ b/medrag_multi_modal/metrics/__init__.py
@@ -0,0 +1,3 @@
+from .mmlu import MMLUOptionAccuracy
+
+__all__ = ["MMLUOptionAccuracy"]
diff --git a/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..086528d16161f1a2c58f2d66f19ff84dfc96b4e7
Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8dcebbc939fb67e7993da6a8a84696402a47e5a
Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc differ
diff --git a/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e03ba593229697c48a5efb03842f43be32d33618
Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc differ
diff --git a/medrag_multi_modal/metrics/base.py b/medrag_multi_modal/metrics/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16cb9219d1f47115fc497f4ddc4f49d376a7c63
--- /dev/null
+++ b/medrag_multi_modal/metrics/base.py
@@ -0,0 +1,108 @@
+from typing import Optional
+
+import numpy as np
+import weave
+
+
+class BaseAccuracyMetric(weave.Scorer):
+    """
+    BaseAccuracyMetric is a class that extends the
+    [`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
+    to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
+    
+    This class is designed to process a list of score rows, each containing a 
+    'correct' key that indicates whether a particular prediction was correct. 
+    The `summarize` method calculates various statistical measures and metrics 
+    based on this data, including:
+    
+    - True and false counts: The number of true and false predictions.
+    - True and false fractions: The proportion of true and false predictions.
+    - Standard error: The standard error of the mean for the true predictions.
+    - Precision: The ratio of true positive predictions to the total number of 
+      positive predictions.
+    - Recall: The ratio of true positive predictions to the total number of 
+      actual positives.
+    - F1 Score: The harmonic mean of precision and recall, providing a balance 
+      between the two metrics.
+    
+    The `summarize` method returns a dictionary containing these metrics, 
+    allowing for a detailed analysis of the model's performance.
+    
+    Methods:
+        summarize(score_rows: list) -> Optional[dict]:
+            Processes the input score rows to compute and return a dictionary 
+            of accuracy metrics.
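+
+    !!! example "Example Usage"
+        A minimal, illustrative sketch of calling `summarize` directly on a few toy
+        score rows; it assumes the `weave` op can be executed locally outside of a
+        logged evaluation.
+
+        ```python
+        from medrag_multi_modal.metrics.base import BaseAccuracyMetric
+
+        metric = BaseAccuracyMetric()
+        summary = metric.summarize(
+            [{"correct": True}, {"correct": False}, {"correct": True}]
+        )
+        print(summary["correct"]["true_fraction"])  # ~0.67
+        ```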
+    """
+    @weave.op()
+    def summarize(self, score_rows: list) -> Optional[dict]:
+        """
+        Summarizes the accuracy metrics from a list of score rows.
+
+        This method processes a list of score rows, each containing a 'correct' key
+        that indicates whether a particular prediction was correct. It calculates
+        various statistical measures and metrics based on this data, including:
+
+        - True and false counts: The number of true and false predictions.
+        - True and false fractions: The proportion of true and false predictions.
+        - Standard error: The standard error of the mean for the true predictions.
+        - Precision: The ratio of true positive predictions to the total number of 
+          positive predictions.
+        - Recall: The ratio of true positive predictions to the total number of 
+          actual positives.
+        - F1 Score: The harmonic mean of precision and recall, providing a balance 
+          between the two metrics.
+
+        The method returns a dictionary containing these metrics, allowing for a 
+        detailed analysis of the model's performance.
+
+        Args:
+            score_rows (list): A list of dictionaries, each containing a 'correct' 
+                key with a boolean value indicating the correctness of a prediction.
+
+        Returns:
+            Optional[dict]: A dictionary containing the calculated accuracy metrics, 
+                or None if the input list is empty.
+        """
+        valid_data = [
+            x.get("correct") for x in score_rows if x.get("correct") is not None
+        ]
+        count_true = list(valid_data).count(True)
+        int_data = [int(x) for x in valid_data]
+
+        sample_mean = np.mean(int_data) if int_data else 0
+        sample_variance = np.var(int_data) if int_data else 0
+        sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
+
+        # Calculate precision, recall, and F1 score
+        true_positives = count_true
+        false_positives = len(valid_data) - count_true
+        false_negatives = len(score_rows) - len(valid_data)
+
+        precision = (
+            true_positives / (true_positives + false_positives)
+            if (true_positives + false_positives) > 0
+            else 0
+        )
+        recall = (
+            true_positives / (true_positives + false_negatives)
+            if (true_positives + false_negatives) > 0
+            else 0
+        )
+        f1_score = (
+            (2 * precision * recall) / (precision + recall)
+            if (precision + recall) > 0
+            else 0
+        )
+
+        return {
+            "correct": {
+                "true_count": count_true,
+                "false_count": len(score_rows) - count_true,
+                "true_fraction": float(sample_mean),
+                "false_fraction": 1.0 - float(sample_mean),
+                "stderr": float(sample_error),
+                "precision": precision,
+                "recall": recall,
+                "f1_score": f1_score,
+            }
+        }
diff --git a/medrag_multi_modal/metrics/mmlu.py b/medrag_multi_modal/metrics/mmlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e182084fd5cecef8834710d611ae0b5680dfe4b
--- /dev/null
+++ b/medrag_multi_modal/metrics/mmlu.py
@@ -0,0 +1,24 @@
+import weave
+
+from medrag_multi_modal.assistant.schema import MedQAResponse
+from medrag_multi_modal.metrics.base import BaseAccuracyMetric
+
+
+class MMLUOptionAccuracy(BaseAccuracyMetric):
+    """
+    MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`.
+    
+    This class is designed to evaluate the accuracy of a multiple-choice question 
+    response by comparing the provided answer with the correct answer from the 
+    given options. It uses the MedQAResponse schema to extract the response 
+    and checks if it matches the correct answer.
+
+    Methods:
+    --------
+    score(output: MedQAResponse, options: list[str], answer: str) -> dict:
+        Compares the provided answer with the correct answer and returns a 
+        dictionary indicating whether the answer is correct.
+    """
+    @weave.op()
+    def score(self, output: MedQAResponse, options: list[str], answer: str):
+        return {"correct": options[answer] == output.response.answer}
diff --git a/medrag_multi_modal/retrieval/__init__.py b/medrag_multi_modal/retrieval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20f0a3bdfd27fd93ea9681dc031ead5b885c1909
--- /dev/null
+++ b/medrag_multi_modal/retrieval/__init__.py
@@ -0,0 +1,3 @@
+from .colpali_retrieval import CalPaliRetriever
+
+__all__ = ["CalPaliRetriever"]
diff --git a/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4050794c922e39d8c9ff5248802458a642aa88b
Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc differ
diff --git a/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e34fff250c919944461288b8d3de0985be872809
Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc differ
diff --git a/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f7a40a45da5990618ca42d6beb0dc6cf691709b
Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc differ
diff --git a/medrag_multi_modal/retrieval/colpali_retrieval.py b/medrag_multi_modal/retrieval/colpali_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..522d964058abda29649115fa732f5a30ae659607
--- /dev/null
+++ b/medrag_multi_modal/retrieval/colpali_retrieval.py
@@ -0,0 +1,255 @@
+import os
+from typing import TYPE_CHECKING, Any, Optional
+
+import weave
+
+if TYPE_CHECKING:
+    from byaldi import RAGMultiModalModel
+
+import wandb
+from PIL import Image
+
+from medrag_multi_modal.utils import get_wandb_artifact
+
+
+class CalPaliRetriever(weave.Model):
+    """
+    CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.
+
+    This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
+    It can be initialized with a pre-trained model or from a specified W&B artifact. The class
+    also provides methods to index new data and to predict/retrieve documents based on a query.
+
+    Attributes:
+        model_name (str): The name of the model to be used for retrieval.
+    """
+
+    model_name: str
+    _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
+    _metadata: Optional[dict] = None
+    _data_artifact_dir: Optional[str] = None
+
+    def __init__(
+        self,
+        model_name: str = "vidore/colpali-v1.2",
+        docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
+        data_artifact_dir: Optional[str] = None,
+        metadata_dataset_name: Optional[str] = None,
+    ):
+        super().__init__(model_name=model_name)
+        from byaldi import RAGMultiModalModel
+
+        self._docs_retrieval_model = (
+            docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
+        )
+        self._data_artifact_dir = data_artifact_dir
+        self._metadata = (
+            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
+            if metadata_dataset_name
+            else None
+        )
+
+    def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
+        """
+        Indexes a dataset of documents and saves the index as a W&B artifact.
+
+        This method retrieves a dataset of documents from a W&B artifact using the provided
+        data artifact name. It then indexes the documents using the document retrieval model
+        and assigns the specified index name. The index is stored locally without storing the
+        collection with the index and overwrites any existing index with the same name.
+
+        If a wandb run is active, the method creates a new W&B artifact with the specified
+        index name and type "colpali-index". It adds the local index directory to the artifact
+        and saves it to W&B, including the provided Weave dataset name in the artifact metadata.
+
+        !!! example "Indexing Data"
+            First you need to install `Byaldi` library by Answer.ai.
+
+            ```bash
+            uv pip install "Byaldi>=0.0.5"
+            ```
+
+            Next, you can index the data by running the following code:
+
+            ```python
+            import wandb
+            from medrag_multi_modal.retrieval import CalPaliRetriever
+
+            wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
+            retriever = CalPaliRetriever()
+            retriever.index(
+                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+                weave_dataset_name="grays-anatomy-images:v0",
+                index_name="grays-anatomy",
+            )
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            data_artifact_name (str): The name of the Weave artifact containing the dataset.
+            weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata.
+            index_name (str): The name to assign to the created index.
+        """
+        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+        self._docs_retrieval_model.index(
+            input_path=data_artifact_dir,
+            index_name=index_name,
+            store_collection_with_index=False,
+            overwrite=True,
+        )
+        if wandb.run:
+            artifact = wandb.Artifact(
+                name=index_name,
+                type="colpali-index",
+                metadata={"weave_dataset_name": weave_dataset_name},
+            )
+            artifact.add_dir(
+                local_path=os.path.join(".byaldi", index_name), name="index"
+            )
+            artifact.save()
+
+    @classmethod
+    def from_wandb_artifact(
+        cls,
+        index_artifact_name: str,
+        metadata_dataset_name: str,
+        data_artifact_name: str,
+    ):
+        """
+        Creates an instance of the class from Weights & Biases (wandb) artifacts.
+
+        This method retrieves the necessary artifacts from wandb to initialize the
+        `CalPaliRetriever`. It fetches the index artifact directory and the data artifact
+        directory using the provided artifact names. It then loads the document retrieval
+        model from the index path within the index artifact directory. Finally, it returns
+        an instance of the class initialized with the retrieved document retrieval model,
+        metadata dataset name, and data artifact directory.
+
+        !!! example "Retrieving Documents"
+            First you need to install `Byaldi` library by Answer.ai.
+
+            ```bash
+            uv pip install "Byaldi>=0.0.5"
+            ```
+
+            Next, you can retrieve the documents by running the following code:
+
+            ```python
+            import weave
+
+            import wandb
+            from medrag_multi_modal.retrieval import CalPaliRetriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = CalPaliRetriever.from_wandb_artifact(
+                index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                metadata_dataset_name="grays-anatomy-images:v0",
+                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+            )
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            index_artifact_name (str): The name of the wandb artifact containing the index.
+            metadata_dataset_name (str): The name of the dataset containing metadata.
+            data_artifact_name (str): The name of the wandb artifact containing the data.
+
+        Returns:
+            An instance of the class initialized with the retrieved document retrieval model,
+            metadata dataset name, and data artifact directory.
+        """
+        from byaldi import RAGMultiModalModel
+
+        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
+        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+        docs_retrieval_model = RAGMultiModalModel.from_index(
+            index_path=os.path.join(index_artifact_dir, "index")
+        )
+        return cls(
+            docs_retrieval_model=docs_retrieval_model,
+            metadata_dataset_name=metadata_dataset_name,
+            data_artifact_dir=data_artifact_dir,
+        )
+
+    @weave.op()
+    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
+        """
+        Predicts and retrieves the top-k most relevant documents/images for a given query
+        using ColPali.
+
+        This function uses the document retrieval model to search for the most relevant
+        documents based on the provided query. It returns a list of dictionaries, each
+        containing the document image, document ID, and the relevance score.
+
+        !!! example "Retrieving Documents"
+            First you need to install `Byaldi` library by Answer.ai.
+
+            ```bash
+            uv pip install "Byaldi>=0.0.5"
+            ```
+
+            Next, you can retrieve the documents by running the following code:
+
+            ```python
+            import weave
+
+            import wandb
+            from medrag_multi_modal.retrieval import CalPaliRetriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = CalPaliRetriever.from_wandb_artifact(
+                index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                metadata_dataset_name="grays-anatomy-images:v0",
+                data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+            )
+            retriever.predict(
+                query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
+                top_k=3,
+            )
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            query (str): The search query string.
+            top_k (int, optional): The number of top results to retrieve. Defaults to 3.
+
+        Returns:
+            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
+                - "doc_image" (PIL.Image.Image): The image of the document.
+                - "doc_id" (str): The ID of the document.
+                - "score" (float): The relevance score of the document.
+        """
+        results = self._docs_retrieval_model.search(query=query, k=top_k)
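+        # Each result's "doc_id" maps to a page image stored as "<doc_id>.png" in the data artifact directory.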
+        retrieved_results = []
+        for result in results:
+            retrieved_results.append(
+                {
+                    "doc_image": Image.open(
+                        os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
+                    ),
+                    "doc_id": result["doc_id"],
+                    "score": result["score"],
+                }
+            )
+        return retrieved_results
diff --git a/medrag_multi_modal/retrieval/common.py b/medrag_multi_modal/retrieval/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a0f1244bb83c595a5f2427e6ecf86b9334bc4c7
--- /dev/null
+++ b/medrag_multi_modal/retrieval/common.py
@@ -0,0 +1,21 @@
+from enum import Enum
+
+
+class SimilarityMetric(Enum):
+    COSINE = "cosine"
+    EUCLIDEAN = "euclidean"
+
+
+def mean_pooling(token_embeddings, mask):
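+    """Average token embeddings across the sequence dimension, ignoring positions that are masked out."""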
+    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
+    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
+    return sentence_embeddings
+
+
+def argsort_scores(scores: list[float], descending: bool = False):
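+    """Sort scores by value (descending when `descending=True`), keeping each score's original index.
+
+    For example, `argsort_scores([0.1, 0.9, 0.5], descending=True)` returns
+    `[{"item": 0.9, "original_index": 1}, {"item": 0.5, "original_index": 2}, {"item": 0.1, "original_index": 0}]`.
+    """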
+    return [
+        {"item": item, "original_index": idx}
+        for idx, item in sorted(
+            list(enumerate(scores)), key=lambda x: x[1], reverse=descending
+        )
+    ]
diff --git a/medrag_multi_modal/retrieval/text_retrieval/__init__.py b/medrag_multi_modal/retrieval/text_retrieval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8482ab713102ca550af488a830af1aacc7fee3
--- /dev/null
+++ b/medrag_multi_modal/retrieval/text_retrieval/__init__.py
@@ -0,0 +1,11 @@
+from .bm25s_retrieval import BM25sRetriever
+from .contriever_retrieval import ContrieverRetriever
+from .medcpt_retrieval import MedCPTRetriever
+from .nv_embed_2 import NVEmbed2Retriever
+
+__all__ = [
+    "BM25sRetriever",
+    "ContrieverRetriever",
+    "MedCPTRetriever",
+    "NVEmbed2Retriever",
+]
diff --git a/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b528262a62c1ae3c962a42877d4cd26ef6d75d
--- /dev/null
+++ b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py
@@ -0,0 +1,238 @@
+import json
+import os
+import shutil
+from typing import Optional, Union
+
+import bm25s
+import huggingface_hub
+import weave
+from bm25s import BM25
+from datasets import Dataset, load_dataset
+from Stemmer import Stemmer
+
+from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface
+
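+# Maps full language names (used to build the stemmer) to the ISO codes that bm25s expects for stopword removal.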
+LANGUAGE_DICT = {
+    "english": "en",
+    "french": "fr",
+    "german": "de",
+}
+
+
+class BM25sRetriever(weave.Model):
+    """
+    `BM25sRetriever` is a class that provides functionality for indexing and
+    retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s) library.
+
+    Args:
+        language (str): The language of the documents to be indexed and retrieved.
+        use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
+        retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided,
+            a new instance is created.
+    """
+
+    language: Optional[str]
+    use_stemmer: bool = True
+    _retriever: Optional[BM25]
+
+    def __init__(
+        self,
+        language: str = "english",
+        use_stemmer: bool = True,
+        retriever: Optional[BM25] = None,
+    ):
+        super().__init__(language=language, use_stemmer=use_stemmer)
+        self._retriever = retriever or BM25()
+
+    def index(
+        self,
+        chunk_dataset: Union[Dataset, str],
+        index_repo_id: Optional[str] = None,
+        cleanup: bool = True,
+    ):
+        """
+        Indexes a dataset of text chunks using the BM25 algorithm.
+
+        This method retrieves a dataset of text chunks from a specified source, tokenizes
+        the text using the BM25 tokenizer with optional stemming, and indexes the tokenized
+        text using the BM25 retriever. If an `index_repo_id` is provided, the index is saved
+        to disk and uploaded to the specified Huggingface Hub repository.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = BM25sRetriever()
+            retriever.index(
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+                index_repo_id="geekyrakshit/grays-anatomy-index",
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
+            cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
+        """
+        chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        corpus = [row["text"] for row in chunk_dataset]
+        corpus_tokens = bm25s.tokenize(
+            corpus,
+            stopwords=LANGUAGE_DICT[self.language],
+            stemmer=Stemmer(self.language) if self.use_stemmer else None,
+        )
+        self._retriever.index(corpus_tokens)
+        if index_repo_id:
+            os.makedirs(".huggingface", exist_ok=True)
+            index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1])
+            self._retriever.save(
+                index_save_dir, corpus=[dict(row) for row in chunk_dataset]
+            )
+            commit_type = (
+                "update"
+                if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
+                else "add"
+            )
+            with open(os.path.join(index_save_dir, "config.json"), "w") as config_file:
+                json.dump(
+                    {
+                        "language": self.language,
+                        "use_stemmer": self.use_stemmer,
+                    },
+                    config_file,
+                    indent=4,
+                )
+            save_to_huggingface(
+                index_repo_id,
+                index_save_dir,
+                commit_message=f"{commit_type}: BM25s index",
+            )
+            if cleanup:
+                shutil.rmtree(index_save_dir)
+
+    @classmethod
+    def from_index(cls, index_repo_id: str):
+        """
+        Creates an instance of the class from a Huggingface repository.
+
+        This class method retrieves a BM25 index artifact from a Huggingface repository,
+        downloads the artifact, and loads the BM25 retriever with the index and its
+        associated corpus. The method also extracts metadata from the artifact to
+        initialize the class instance with the appropriate language and stemming
+        settings.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = BM25sRetriever.from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
+            ```
+
+        Args:
+            index_repo_id (str): The Huggingface repository of the index artifact to be loaded.
+
+        Returns:
+            An instance of the class initialized with the BM25 retriever and metadata
+            from the artifact.
+        """
+        index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
+        retriever = bm25s.BM25.load(index_dir, load_corpus=True)
+        with open(os.path.join(index_dir, "config.json"), "r") as config_file:
+            config = json.load(config_file)
+        return cls(retriever=retriever, **config)
+
+    @weave.op()
+    def retrieve(self, query: str, top_k: int = 2):
+        """
+        Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.
+
+        This method tokenizes the input query using the BM25 tokenizer, which takes into
+        account the language-specific stopwords and optional stemming. It then retrieves
+        the top-k most relevant chunks from the BM25 index based on the tokenized query.
+        The results are returned as a list of dictionaries, each containing a chunk and
+        its corresponding relevance score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = BM25sRetriever.from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
+            retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its
+                relevance score.
+        """
+        query_tokens = bm25s.tokenize(
+            query,
+            stopwords=LANGUAGE_DICT[self.language],
+            stemmer=Stemmer(self.language) if self.use_stemmer else None,
+        )
+        results = self._retriever.retrieve(query_tokens, k=top_k)
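+        # bm25s returns parallel arrays of documents and scores; pair them into result dictionaries.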
+        retrieved_chunks = []
+        for chunk, score in zip(
+            results.documents.flatten().tolist(),
+            results.scores.flatten().tolist(),
+        ):
+            retrieved_chunks.append({**chunk, **{"score": score}})
+        return retrieved_chunks
+
+    @weave.op()
+    def predict(self, query: str, top_k: int = 2):
+        """
+        Predicts the top-k most relevant chunks for a given query using the BM25 algorithm.
+
+        This function is a wrapper around the `retrieve` method. It takes an input query string,
+        tokenizes it using the BM25 tokenizer, and retrieves the top-k most relevant chunks from
+        the BM25 index. The results are returned as a list of dictionaries, each containing a chunk
+        and its corresponding relevance score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = BM25sRetriever.from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
+            retrieved_chunks = retriever.predict(query="What are Ribosomes?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        return self.retrieve(query, top_k)
diff --git a/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..77da42e2693828232048723718ca29ae996c788f
--- /dev/null
+++ b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py
@@ -0,0 +1,310 @@
+import json
+import os
+import shutil
+from typing import Optional, Union
+
+import huggingface_hub
+import safetensors
+import safetensors.torch
+import torch
+import torch.nn.functional as F
+import weave
+from datasets import Dataset, load_dataset
+from rich.progress import track
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    BertPreTrainedModel,
+    PreTrainedTokenizerFast,
+)
+
+from medrag_multi_modal.retrieval.common import (
+    SimilarityMetric,
+    argsort_scores,
+    mean_pooling,
+)
+from medrag_multi_modal.utils import (
+    fetch_from_huggingface,
+    get_torch_backend,
+    save_to_huggingface,
+)
+
+
+class ContrieverRetriever(weave.Model):
+    """
+    `ContrieverRetriever` is a class to perform retrieval tasks using the Contriever model.
+
+    It provides methods to encode text data into embeddings, index a dataset of text chunks,
+    and retrieve the most relevant chunks for a given query based on similarity metrics.
+
+    Args:
+        model_name (str): The name of the pre-trained model to use for encoding.
+        vector_index (Optional[torch.Tensor]): The tensor containing the vector representations
+            of the indexed chunks.
+        chunk_dataset (Optional[list[dict]]): The dataset of text chunks to be indexed.
+    """
+
+    model_name: str
+    _chunk_dataset: Optional[list[dict]]
+    _tokenizer: PreTrainedTokenizerFast
+    _model: BertPreTrainedModel
+    _vector_index: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        model_name: str = "facebook/contriever",
+        vector_index: Optional[torch.Tensor] = None,
+        chunk_dataset: Optional[list[dict]] = None,
+    ):
+        super().__init__(model_name=model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._model = AutoModel.from_pretrained(self.model_name).to(get_torch_backend())
+        self._vector_index = vector_index
+        self._chunk_dataset = chunk_dataset
+
+    def encode(self, corpus: list[str], batch_size: int) -> torch.Tensor:
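+        """Encode a list of texts into mean-pooled embeddings, processing the corpus in batches to bound memory usage."""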
+        embeddings = []
+        if batch_size > 1:
+            iterable = track(
+                range(0, len(corpus), batch_size),
+                description=f"Encoding corpus using {self.model_name}",
+            )
+        else:
+            iterable = range(0, len(corpus), batch_size)
+        for idx in iterable:
+            batch = corpus[idx : idx + batch_size]
+            inputs = self._tokenizer(
+                batch, padding=True, truncation=True, return_tensors="pt"
+            ).to(get_torch_backend())
+            with torch.no_grad():
+                outputs = self._model(**inputs)
+                batch_embeddings = mean_pooling(outputs[0], inputs["attention_mask"])
+                embeddings.append(batch_embeddings)
+        embeddings = torch.cat(embeddings, dim=0)
+        return embeddings
+
+    def index(
+        self,
+        chunk_dataset: Union[str, Dataset],
+        index_repo_id: Optional[str] = None,
+        cleanup: bool = True,
+        batch_size: int = 32,
+    ):
+        """
+        Indexes a dataset of text chunks and optionally saves the vector index to a file.
+
+        This method loads a dataset of text chunks from a Huggingface dataset (or accepts an
+        in-memory dataset), encodes the text chunks into vector representations using the
+        Contriever model, and stores the resulting vector index. If an `index_repo_id` is
+        provided, the vector index is saved to disk in the safetensors format and uploaded
+        to the specified Huggingface Hub repository.
+
+        !!! example "Example Usage"
+            ```python
+            from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever
+
+            retriever = ContrieverRetriever()
+            retriever.index(
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
+                batch_size=256,
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
+            cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
+            batch_size (int, optional): The batch size to use for encoding the corpus.
+        """
+        self._chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        corpus = [row["text"] for row in self._chunk_dataset]
+        with torch.no_grad():
+            vector_index = self.encode(corpus, batch_size)
+            self._vector_index = vector_index
+            if index_repo_id:
+                index_save_dir = os.path.join(
+                    ".huggingface", index_repo_id.split("/")[-1]
+                )
+                os.makedirs(index_save_dir, exist_ok=True)
+                safetensors.torch.save_file(
+                    {"vector_index": vector_index.cpu()},
+                    os.path.join(index_save_dir, "vector_index.safetensors"),
+                )
+                commit_type = (
+                    "update"
+                    if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
+                    else "add"
+                )
+                with open(
+                    os.path.join(index_save_dir, "config.json"), "w"
+                ) as config_file:
+                    json.dump(
+                        {"model_name": self.model_name},
+                        config_file,
+                        indent=4,
+                    )
+                save_to_huggingface(
+                    index_repo_id,
+                    index_save_dir,
+                    commit_message=f"{commit_type}: Contriever index",
+                )
+                if cleanup:
+                    shutil.rmtree(index_save_dir)
+
+    @classmethod
+    def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str):
+        """
+        Creates an instance of the class from a Huggingface repository.
+
+        This method retrieves a vector index and metadata from a Huggingface index repository
+        and loads the corresponding dataset of text chunks. The vector index is loaded from a
+        safetensors file and moved to the appropriate device (CPU or GPU). The method then
+        returns an instance of the class initialized with the retrieved model name, vector
+        index, and chunk dataset.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever
+
+            load_dotenv()
+            retriever = ContrieverRetriever().from_index(
+                index_repo_id="geekyrakshit/grays-anatomy-index-contriever",
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the indexed text chunks. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (str): The Huggingface repository of the index artifact to be loaded.
+
+        Returns:
+            An instance of the class initialized with the retrieved model name, vector index,
+            and chunk dataset.
+        """
+        index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
+        with safetensors.torch.safe_open(
+            os.path.join(index_dir, "vector_index.safetensors"), framework="pt"
+        ) as f:
+            vector_index = f.get_tensor("vector_index")
+        device = torch.device(get_torch_backend())
+        vector_index = vector_index.to(device)
+        chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        with open(os.path.join(index_dir, "config.json"), "r") as config_file:
+            metadata = json.load(config_file)
+        return cls(
+            model_name=metadata["model_name"],
+            vector_index=vector_index,
+            chunk_dataset=chunk_dataset,
+        )
+
+    @weave.op()
+    def retrieve(
+        self,
+        query: str,
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+        This method encodes the input query into an embedding and computes similarity scores between
+        the query embedding and the precomputed vector index. The similarity metric can be either
+        cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+        are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = ContrieverRetriever().from_index(
+                index_repo_id="geekyrakshit/grays-anatomy-index-contriever",
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+            )
+            retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        query = [query]
+        device = torch.device(get_torch_backend())
+        with torch.no_grad():
+            query_embedding = self.encode(query, batch_size=1).to(device)
+            if metric == SimilarityMetric.EUCLIDEAN:
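+                # Note: this branch ranks by the raw inner product of the embeddings rather than an actual Euclidean distance.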
+                scores = torch.squeeze(query_embedding @ self._vector_index.T)
+            else:
+                scores = F.cosine_similarity(query_embedding, self._vector_index)
+            scores = scores.cpu().numpy().tolist()
+        scores = argsort_scores(scores, descending=True)[:top_k]
+        retrieved_chunks = []
+        for score in scores:
+            retrieved_chunks.append(
+                {
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
+                }
+            )
+        return retrieved_chunks
+
+    @weave.op()
+    def predict(
+        self,
+        query: str,
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Predicts the top-k most relevant chunks for a given query using the specified similarity metric.
+
+        This function is a wrapper around the `retrieve` method. It takes an input query string,
+        retrieves the top-k most relevant chunks from the precomputed vector index based on the
+        specified similarity metric, and returns the results as a list of dictionaries, each containing
+        a chunk and its corresponding relevance score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever
+
+            load_dotenv()
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = ContrieverRetriever().from_index(
+                index_repo_id="geekyrakshit/grays-anatomy-index-contriever",
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+            )
+            retrieved_chunks = retriever.predict(query="What are Ribosomes?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        return self.retrieve(query, top_k, metric)
diff --git a/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac400dacf329a532e99b3cac7967326b8c3d2c0f
--- /dev/null
+++ b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py
@@ -0,0 +1,335 @@
+import json
+import os
+import shutil
+from typing import Optional, Union
+
+import huggingface_hub
+import safetensors
+import safetensors.torch
+import torch
+import torch.nn.functional as F
+import weave
+from datasets import Dataset, load_dataset
+from rich.progress import track
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    BertPreTrainedModel,
+    PreTrainedTokenizerFast,
+)
+
+from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores
+from medrag_multi_modal.utils import (
+    fetch_from_huggingface,
+    get_torch_backend,
+    save_to_huggingface,
+)
+
+
+class MedCPTRetriever(weave.Model):
+    """
+    A class to retrieve relevant text chunks using MedCPT models.
+
+    This class provides methods to index a dataset of text chunks and retrieve the most relevant
+    chunks for a given query using MedCPT models. It uses separate models for encoding queries
+    and articles, and supports both cosine similarity and Euclidean distance as similarity metrics.
+
+    Args:
+        query_encoder_model_name (str): The name of the model used for encoding queries.
+        article_encoder_model_name (str): The name of the model used for encoding articles.
+        chunk_size (Optional[int]): The maximum length of text chunks.
+        vector_index (Optional[torch.Tensor]): The vector index of encoded text chunks.
+        chunk_dataset (Optional[list[dict]]): The dataset of text chunks.
+    """
+
+    query_encoder_model_name: str
+    article_encoder_model_name: str
+    chunk_size: Optional[int]
+    _chunk_dataset: Optional[list[dict]]
+    _query_tokenizer: PreTrainedTokenizerFast
+    _article_tokenizer: PreTrainedTokenizerFast
+    _query_encoder_model: BertPreTrainedModel
+    _article_encoder_model: BertPreTrainedModel
+    _vector_index: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        query_encoder_model_name: str = "ncbi/MedCPT-Query-Encoder",
+        article_encoder_model_name: str = "ncbi/MedCPT-Article-Encoder",
+        chunk_size: Optional[int] = None,
+        vector_index: Optional[torch.Tensor] = None,
+        chunk_dataset: Optional[list[dict]] = None,
+    ):
+        super().__init__(
+            query_encoder_model_name=query_encoder_model_name,
+            article_encoder_model_name=article_encoder_model_name,
+            chunk_size=chunk_size,
+        )
+        self._query_tokenizer = AutoTokenizer.from_pretrained(
+            self.query_encoder_model_name
+        )
+        self._article_tokenizer = AutoTokenizer.from_pretrained(
+            self.article_encoder_model_name
+        )
+        self._query_encoder_model = AutoModel.from_pretrained(
+            self.query_encoder_model_name
+        ).to(get_torch_backend())
+        self._article_encoder_model = AutoModel.from_pretrained(
+            self.article_encoder_model_name
+        ).to(get_torch_backend())
+        self._chunk_dataset = chunk_dataset
+        self._vector_index = vector_index
+
+    def index(
+        self,
+        chunk_dataset: Union[str, Dataset],
+        index_repo_id: Optional[str] = None,
+        cleanup: bool = True,
+        batch_size: int = 32,
+    ):
+        """
+        Indexes a dataset of text chunks using the MedCPT model and optionally saves the vector index.
+
+        This method retrieves a dataset of text chunks from a specified source, encodes the text
+        chunks into vector representations using the article encoder model, and stores the
+        resulting vector index. If an `index_repo_id` is provided, the vector index is saved
+        to disk in the safetensors format and uploaded to the specified Huggingface Hub repository.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever
+
+            load_dotenv()
+            retriever = MedCPTRetriever()
+            retriever.index(
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+                index_repo_id="geekyrakshit/grays-anatomy-index-medcpt",
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
+            cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
+            batch_size (int, optional): The batch size to use for encoding the corpus.
+
+        """
+        self._chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        corpus = [row["text"] for row in self._chunk_dataset]
+        vector_indices = []
+        with torch.no_grad():
+            for idx in track(
+                range(0, len(corpus), batch_size),
+                description="Encoding corpus using MedCPT",
+            ):
+                batch = corpus[idx : idx + batch_size]
+                encoded = self._article_tokenizer(
+                    batch,
+                    truncation=True,
+                    padding=True,
+                    return_tensors="pt",
+                    max_length=self.chunk_size,
+                ).to(get_torch_backend())
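+                # Use the [CLS] token embedding from the article encoder as each chunk's vector representation.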
+                batch_vectors = (
+                    self._article_encoder_model(**encoded)
+                    .last_hidden_state[:, 0, :]
+                    .contiguous()
+                )
+                vector_indices.append(batch_vectors)
+
+            vector_index = torch.cat(vector_indices, dim=0)
+            self._vector_index = vector_index
+            if index_repo_id:
+                index_save_dir = os.path.join(
+                    ".huggingface", index_repo_id.split("/")[-1]
+                )
+                os.makedirs(index_save_dir, exist_ok=True)
+                safetensors.torch.save_file(
+                    {"vector_index": self._vector_index.cpu()},
+                    os.path.join(index_save_dir, "vector_index.safetensors"),
+                )
+                commit_type = (
+                    "update"
+                    if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
+                    else "add"
+                )
+                with open(
+                    os.path.join(index_save_dir, "config.json"), "w"
+                ) as config_file:
+                    json.dump(
+                        {
+                            "query_encoder_model_name": self.query_encoder_model_name,
+                            "article_encoder_model_name": self.article_encoder_model_name,
+                            "chunk_size": self.chunk_size,
+                        },
+                        config_file,
+                        indent=4,
+                    )
+                save_to_huggingface(
+                    index_repo_id,
+                    index_save_dir,
+                    commit_message=f"{commit_type}: Contriever index",
+                )
+                if cleanup:
+                    shutil.rmtree(index_save_dir)
+
+    @classmethod
+    def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str):
+        """
+        Creates an instance of the class from a Huggingface repository.
+
+        This method retrieves a vector index and metadata from a Huggingface repository.
+        It also retrieves a dataset of text chunks from the specified source. The vector
+        index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU).
+        The method then returns an instance of the class initialized with the retrieved
+        model names, vector index, and chunk dataset.
+
+        !!! example "Example Usage"
+            ```python
+            from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever
+
+            retriever = MedCPTRetriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the indexed text chunks. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (str): The Huggingface repository of the index artifact to be loaded.
+
+        Returns:
+            An instance of the class initialized with the retrieved model name, vector index, and chunk dataset.
+        """
+        index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
+        with safetensors.torch.safe_open(
+            os.path.join(index_dir, "vector_index.safetensors"), framework="pt"
+        ) as f:
+            vector_index = f.get_tensor("vector_index")
+        device = torch.device(get_torch_backend())
+        vector_index = vector_index.to(device)
+        with open(os.path.join(index_dir, "config.json"), "r") as config_file:
+            metadata = json.load(config_file)
+        chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        return cls(
+            query_encoder_model_name=metadata["query_encoder_model_name"],
+            article_encoder_model_name=metadata["article_encoder_model_name"],
+            chunk_size=metadata["chunk_size"],
+            vector_index=vector_index,
+            chunk_dataset=chunk_dataset,
+        )
+
+    @weave.op()
+    def retrieve(
+        self,
+        query: str,
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+        This method encodes the input query into an embedding and computes similarity scores between
+        the query embedding and the precomputed vector index. The similarity metric can be either
+        cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+        are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = MedCPTRetriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            retriever.retrieve(query="What is ribosome?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        query = [query]
+        device = torch.device(get_torch_backend())
+        with torch.no_grad():
+            encoded = self._query_tokenizer(
+                query,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            ).to(device)
+            query_embedding = self._query_encoder_model(**encoded).last_hidden_state[
+                :, 0, :
+            ]
+            query_embedding = query_embedding.to(device)
+            if metric == SimilarityMetric.EUCLIDEAN:
+                scores = torch.squeeze(query_embedding @ self._vector_index.T)
+            else:
+                scores = F.cosine_similarity(query_embedding, self._vector_index)
+            scores = scores.cpu().numpy().tolist()
+        scores = argsort_scores(scores, descending=True)[:top_k]
+        retrieved_chunks = []
+        for score in scores:
+            retrieved_chunks.append(
+                {
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
+                }
+            )
+        return retrieved_chunks
+
+    @weave.op()
+    def predict(
+        self,
+        query: str,
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Predicts the most relevant chunks for a given query.
+
+        This function uses the `retrieve` method to find the top-k relevant chunks
+        from the dataset based on the input query. It allows specifying the number
+        of top relevant chunks to retrieve and the similarity metric to use for scoring.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = MedCPTRetriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            retriever.predict(query="What is ribosome?")
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        return self.retrieve(query, top_k, metric)
diff --git a/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e61883f8e801f105a2f7e039e3ed16b89dd2e9
--- /dev/null
+++ b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py
@@ -0,0 +1,332 @@
+import json
+import os
+import shutil
+from typing import Optional, Union
+
+import huggingface_hub
+import safetensors
+import torch
+import torch.nn.functional as F
+import weave
+from datasets import Dataset, load_dataset
+from rich.progress import track
+from sentence_transformers import SentenceTransformer
+
+from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores
+from medrag_multi_modal.utils import (
+    fetch_from_huggingface,
+    get_torch_backend,
+    save_to_huggingface,
+)
+
+
+class NVEmbed2Retriever(weave.Model):
+    """
+    `NVEmbed2Retriever` is a class for retrieving relevant text chunks from a dataset using the
+    [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2) model.
+
+    This class leverages the SentenceTransformer model to encode text chunks into vector representations and
+    performs similarity-based retrieval. It supports indexing a dataset of text chunks, saving the vector index,
+    and retrieving the most relevant chunks for a given query.
+
+    Args:
+        model_name (str): The name of the pre-trained model to use for encoding.
+        vector_index (Optional[torch.Tensor]): The tensor containing the vector representations of the indexed chunks.
+        chunk_dataset (Optional[list[dict]]): The dataset of text chunks to be indexed.
+    """
+
+    model_name: str
+    _chunk_dataset: Optional[list[dict]]
+    _model: SentenceTransformer
+    _vector_index: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        model_name: str = "nvidia/NV-Embed-v2",
+        vector_index: Optional[torch.Tensor] = None,
+        chunk_dataset: Optional[list[dict]] = None,
+    ):
+        super().__init__(model_name=model_name)
+        self._model = SentenceTransformer(
+            self.model_name,
+            trust_remote_code=True,
+            model_kwargs={"torch_dtype": torch.float16},
+            device=get_torch_backend(),
+        )
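+        # Settings recommended for NV-Embed-v2: allow long input sequences and pad on the right.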
+        self._model.max_seq_length = 32768
+        self._model.tokenizer.padding_side = "right"
+        self._vector_index = vector_index
+        self._chunk_dataset = chunk_dataset
+
+    def add_eos(self, input_examples):
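+        """Append the tokenizer's EOS token to each input example before encoding."""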
+        input_examples = [
+            input_example + self._model.tokenizer.eos_token
+            for input_example in input_examples
+        ]
+        return input_examples
+
+    def index(
+        self,
+        chunk_dataset: Union[str, Dataset],
+        index_repo_id: Optional[str] = None,
+        cleanup: bool = True,
+        batch_size: int = 8,
+    ):
+        """
+        Indexes a dataset of text chunks and optionally saves the vector index to a Huggingface repository.
+
+        This method retrieves a dataset of text chunks from a specified source, encodes the
+        text chunks into vector representations using the NV-Embed-v2 model, and stores the
+        resulting vector index. If an index repository ID is provided, the vector index is saved to
+        a file in the safetensors format within the specified Huggingface repository.
+
+        !!! example "Example Usage"
+            ```python
+            from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever
+
+            retriever = NVEmbed2Retriever()
+            retriever.index(
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+            )
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
+            cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
+            batch_size (int, optional): The batch size to use for encoding the corpus.
+        """
+        self._chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        corpus = [row["text"] for row in self._chunk_dataset]
+        vector_indices = []
+
+        for idx in track(
+            range(0, len(corpus), batch_size),
+            description="Encoding corpus using NV-Embed-v2",
+        ):
+            batch = corpus[idx : idx + batch_size]
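+            # Embeddings are L2-normalized (normalize_embeddings=True), so inner products equal cosine similarities.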
+            batch_embeddings = self._model.encode(
+                self.add_eos(batch), batch_size=len(batch), normalize_embeddings=True
+            )
+            vector_indices.append(torch.tensor(batch_embeddings))
+
+        self._vector_index = torch.cat(vector_indices, dim=0)
+        with torch.no_grad():
+            if index_repo_id:
+                index_save_dir = os.path.join(
+                    ".huggingface", index_repo_id.split("/")[-1]
+                )
+                os.makedirs(index_save_dir, exist_ok=True)
+                safetensors.torch.save_file(
+                    {"vector_index": self._vector_index.cpu()},
+                    os.path.join(index_save_dir, "vector_index.safetensors"),
+                )
+                commit_type = (
+                    "update"
+                    if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
+                    else "add"
+                )
+                with open(
+                    os.path.join(index_save_dir, "config.json"), "w"
+                ) as config_file:
+                    json.dump(
+                        {"model_name": self.model_name},
+                        config_file,
+                        indent=4,
+                    )
+                save_to_huggingface(
+                    index_repo_id,
+                    index_save_dir,
+                    commit_message=f"{commit_type}: Contriever index",
+                )
+                if cleanup:
+                    shutil.rmtree(index_save_dir)
+
+    @classmethod
+    def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str):
+        """
+        Creates an instance of the class from a Huggingface repository.
+
+        This method retrieves a vector index and metadata from a Huggingface repository. It also
+        retrieves a dataset of text chunks from a Huggingface dataset repository. The vector index
+        is loaded from a safetensors file and moved to the appropriate device (CPU or GPU). The
+        method then returns an instance of the class initialized with the retrieved model name,
+        vector index, and chunk dataset.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever
+
+            retriever = NVEmbed2Retriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the indexed text chunks. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (str): The Huggingface repository of the index artifact to be loaded.
+
+        Returns:
+            An instance of the class initialized with the retrieved model name, vector index,
+            and chunk dataset.
+        """
+        index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
+        with safetensors.torch.safe_open(
+            os.path.join(index_dir, "vector_index.safetensors"), framework="pt"
+        ) as f:
+            vector_index = f.get_tensor("vector_index")
+        device = torch.device(get_torch_backend())
+        vector_index = vector_index.to(device)
+        chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        with open(os.path.join(index_dir, "config.json"), "r") as config_file:
+            metadata = json.load(config_file)
+        return cls(
+            model_name=metadata["model_name"],
+            vector_index=vector_index,
+            chunk_dataset=chunk_dataset,
+        )
+
+    @weave.op()
+    def retrieve(
+        self,
+        query: list[str],
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+        This method encodes the input query into an embedding and computes similarity scores between
+        the query embedding and the precomputed vector index. The similarity metric can be either
+        cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+        are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = NVEmbed2Retriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            retriever.retrieve(query=["What is ribosome?"])
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            query (list[str]): The input query strings to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        device = torch.device(get_torch_backend())
+        with torch.no_grad():
+            query_embedding = self._model.encode(
+                self.add_eos(query), normalize_embeddings=True
+            )
+            query_embedding = torch.from_numpy(query_embedding).to(device)
+            if metric == SimilarityMetric.EUCLIDEAN:
+                scores = torch.squeeze(query_embedding @ self._vector_index.T)
+            else:
+                scores = F.cosine_similarity(query_embedding, self._vector_index)
+            scores = scores.cpu().numpy().tolist()
+        scores = argsort_scores(scores, descending=True)[:top_k]
+        retrieved_chunks = []
+        for score in scores:
+            retrieved_chunks.append(
+                {
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
+                }
+            )
+        return retrieved_chunks
+
+    @weave.op()
+    def predict(
+        self,
+        query: str,
+        top_k: int = 2,
+        metric: SimilarityMetric = SimilarityMetric.COSINE,
+    ):
+        """
+        Predicts the top-k most relevant chunks for a given query using the specified similarity metric.
+
+        This method formats the input query string by prepending an instruction prompt and then calls the
+        `retrieve` method to get the most relevant chunks. The similarity metric can be either cosine similarity
+        or Euclidean distance. The top-k chunks with the highest similarity scores are returned.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever
+
+            weave.init(project_name="ml-colabs/medrag-multi-modal")
+            retriever = NVEmbed2Retriever.from_index(
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+            )
+            retriever.predict(query="What is ribosome?")
+            ```
+
+        ??? note "Optional Speedup using Flash Attention"
+            If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+            installing the `flash-attn` package.
+
+            ```bash
+            uv pip install flash-attn --no-build-isolation
+            ```
+
+        Args:
+            query (str): The input query string to search for relevant chunks.
+            top_k (int, optional): The number of top relevant chunks to retrieve.
+            metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+        Returns:
+            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+        """
+        query = [
+            f"""Instruct: Given a question, retrieve passages that answer the question
+Query: {query}"""
+        ]
+        return self.retrieve(query, top_k, metric)
diff --git a/medrag_multi_modal/semantic_chunking.py b/medrag_multi_modal/semantic_chunking.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b82b3e7c40f9b80cd127a1646c71330f739457
--- /dev/null
+++ b/medrag_multi_modal/semantic_chunking.py
@@ -0,0 +1,135 @@
+import asyncio
+from typing import Callable, Optional, Union
+
+import huggingface_hub
+import semchunk
+import tiktoken
+import tokenizers
+from datasets import Dataset, concatenate_datasets, load_dataset
+from rich.progress import track
+from transformers import PreTrainedTokenizer
+
+TOKENIZER_OR_TOKEN_COUNTER = Union[
+    str,
+    tiktoken.Encoding,
+    PreTrainedTokenizer,
+    tokenizers.Tokenizer,
+    Callable[[str], int],
+]
+
+
+class SemanticChunker:
+    """
+    SemanticChunker is a class that chunks documents into smaller segments and
+    publishes them as datasets.
+
+    This class uses the `semchunk` library to break down large documents into
+    smaller, manageable chunks based on a specified tokenizer or token counter.
+    This is particularly useful for processing large text datasets where
+    smaller segments are needed for analysis or other operations.
+
+    !!! example "Example Usage"
+        ```python
+        from medrag_multi_modal.semantic_chunking import SemanticChunker
+
+
+        chunker = SemanticChunker(chunk_size=256)
+        chunker.chunk(
+            document_dataset="geekyrakshit/grays-anatomy-test",
+            chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
+        )
+        ```
+
+    Args:
+        tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
+            token counter to be used for chunking.
+        chunk_size (Optional[int]): The size of each chunk. If not specified, the
+            default chunk size from `semchunk` will be used.
+        max_token_chars (Optional[int]): The maximum number of characters per token.
+            If not specified, the default value from `semchunk` will be used.
+        memoize (bool): Whether to memoize the chunking process for efficiency.
+            Default is True.
+    """
+
+    def __init__(
+        self,
+        tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
+        chunk_size: Optional[int] = None,
+        max_token_chars: Optional[int] = None,
+        memoize: bool = True,
+    ) -> None:
+        self.chunker = semchunk.chunkerify(
+            tokenizer_or_token_counter,
+            chunk_size=chunk_size,
+            max_token_chars=max_token_chars,
+            memoize=memoize,
+        )
+
+    def chunk(
+        self,
+        document_dataset: Union[Dataset, str],
+        chunk_dataset_repo_id: Optional[str] = None,
+        overwrite_dataset: bool = False,
+    ) -> Dataset:
+        """
+        Chunks a document dataset into smaller segments and publishes them as a new dataset.
+
+        This function takes a document dataset, either as a HuggingFace Dataset object or a string
+        representing the dataset repository ID, and chunks the documents into smaller segments using
+        the specified chunker. The resulting chunks are then optionally published to a HuggingFace
+        dataset repository.
+
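+        !!! example "Example Usage"
+            A minimal sketch that mirrors the class-level example above; substitute your
+            own dataset repository IDs as needed.
+
+            ```python
+            from medrag_multi_modal.semantic_chunking import SemanticChunker
+
+            chunker = SemanticChunker(chunk_size=256)
+            chunker.chunk(
+                document_dataset="geekyrakshit/grays-anatomy-test",
+                chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
+            )
+            ```
+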
+        Args:
+            document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
+                a HuggingFace Dataset object or a string representing the dataset repository ID.
+            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
+                the chunks to, if provided. Defaults to None.
+            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
+
+        Returns:
+            Dataset: A HuggingFace Dataset object containing the chunks.
+        """
+        document_dataset = (
+            load_dataset(document_dataset, split="corpus")
+            if isinstance(document_dataset, str)
+            else document_dataset
+        ).to_list()
+
+        chunks = []
+
+        async def process_document(idx, document):
+            document_chunks = self.chunker.chunk(str(document["text"]))
+            for chunk in document_chunks:
+                chunk_dict = {"document_idx": idx, "text": chunk}
+                for key, value in document.items():
+                    if key not in chunk_dict:
+                        chunk_dict[key] = value
+                chunks.append(chunk_dict)
+
+        async def process_all_documents():
+            tasks = []
+            for idx, document in track(
+                enumerate(document_dataset),
+                total=len(document_dataset),
+                description="Chunking documents",
+            ):
+                tasks.append(process_document(idx, document))
+            await asyncio.gather(*tasks)
+
+        asyncio.run(process_all_documents())
+
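+        # Restore the original document order before building the dataset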
+        chunks.sort(key=lambda x: x["document_idx"])
+
+        dataset = Dataset.from_list(chunks)
+        if chunk_dataset_repo_id:
+            if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
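+                # Keep the existing published chunks unless the caller asked to overwrite them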
+                if not overwrite_dataset:
+                    dataset = concatenate_datasets(
+                        [
+                            dataset,
+                            load_dataset(chunk_dataset_repo_id, split="chunks"),
+                        ]
+                    )
+            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")
+
+        return dataset
diff --git a/medrag_multi_modal/utils.py b/medrag_multi_modal/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b6237a4879199653c77235eb3ee9b0d32c764a
--- /dev/null
+++ b/medrag_multi_modal/utils.py
@@ -0,0 +1,86 @@
+import base64
+import io
+from typing import Any
+
+import jsonlines
+import torch
+import wandb
+from huggingface_hub import HfApi
+from PIL import Image
+
+
+def get_wandb_artifact(
+    artifact_name: str,
+    artifact_type: str,
+    get_metadata: bool = False,
+) -> str | tuple[str, dict]:
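+    # Download via the active run (tracking the artifact as a run input) or the public API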
+    if wandb.run:
+        artifact = wandb.use_artifact(artifact_name, type=artifact_type)
+        artifact_dir = artifact.download()
+    else:
+        api = wandb.Api()
+        artifact = api.artifact(artifact_name)
+        artifact_dir = artifact.download()
+    if get_metadata:
+        return artifact_dir, artifact.metadata
+    return artifact_dir
+
+
+def get_torch_backend():
+    # Prefer CUDA, then Apple MPS, and fall back to CPU
+    if torch.cuda.is_available() and torch.backends.cuda.is_built():
+        return "cuda"
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        return "mps"
+    return "cpu"
+
+
+def base64_encode_image(image: Image.Image, mimetype: str) -> str:
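+    # The image is always re-encoded as PNG; `mimetype` only sets the data-URI prefix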
+    image.load()
+    if image.mode not in ("RGB", "RGBA"):
+        image = image.convert("RGB")
+    byte_arr = io.BytesIO()
+    image.save(byte_arr, format="PNG")
+    encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
+    encoded_string = f"data:{mimetype};base64,{encoded_string}"
+    return str(encoded_string)
+
+
+def read_jsonl_file(file_path: str) -> list[dict[str, Any]]:
+    # Only the first JSON line in the file is read and returned
+    with jsonlines.open(file_path) as reader:
+        for obj in reader:
+            return obj
+
+
+def save_to_huggingface(
+    repo_id: str, local_dir: str, commit_message: str, private: bool = False
+):
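+    # Create the model repo if it does not already exist, then upload the local directory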
+    api = HfApi()
+    repo_url = api.create_repo(
+        repo_id=repo_id,
+        token=api.token,
+        private=private,
+        repo_type="model",
+        exist_ok=True,
+    )
+    repo_id = repo_url.repo_id
+    api.upload_folder(
+        repo_id=repo_id,
+        commit_message=commit_message,
+        token=api.token,
+        folder_path=local_dir,
+        repo_type=repo_url.repo_type,
+    )
+
+
+def fetch_from_huggingface(repo_id: str, local_dir: str) -> str:
+    api = HfApi()
+    repo_url = api.repo_info(repo_id)
+    if repo_url is None:
+        raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")
+
+    snapshot = api.snapshot_download(repo_id, revision=None, local_dir=local_dir)
+    if snapshot is None:
+        raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")
+    return snapshot
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..c92b8110073e43000d923779828818dd796c888a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,101 @@
+[project]
+name = "medrag-multi-modal"
+version = "0.0.1"
+description = ""
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "bm25s[full]>=0.2.2",
+    "datasets>=3.1.0",
+    "einops>=0.8.0",
+    "firerequests>=0.0.7",
+    "pdf2image>=1.17.0",
+    "python-dotenv>=1.0.1",
+    "pymupdf4llm>=0.0.17",
+    "weave>=0.51.14",
+    "pip>=24.2",
+    "uv>=0.4.20",
+    "pytest>=8.3.3",
+    "PyPDF2>=3.0.1",
+    "PyStemmer>=2.2.0.3",
+    "safetensors>=0.4.5",
+    "isort>=5.13.2",
+    "black>=24.10.0",
+    "ruff>=0.6.9",
+    "marker-pdf>=0.2.17",
+    "mkdocs>=1.6.1",
+    "mkdocstrings>=0.26.1",
+    "mkdocstrings-python>=1.11.1",
+    "mkdocs-material>=9.5.39",
+    "mkdocs-minify-plugin>=0.8.0",
+    "mkdocs-glightbox>=0.4.0",
+    "mkdocs-jupyter>=0.25.0",
+    "jupyter>=1.1.1",
+    "pdfplumber>=0.11.4",
+    "semchunk>=2.2.0",
+    "tiktoken>=0.8.0",
+    "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
+    "streamlit>=1.39.0",
+]
+
+[project.optional-dependencies]
+app = [
+    "streamlit>=1.39.0",
+]
+core = [
+    "bm25s[full]>=0.2.2",
+    "datasets>=3.1.0",
+    "einops>=0.8.0",
+    "firerequests>=0.0.7",
+    "marker-pdf>=0.2.17",
+    "pdf2image>=1.17.0",
+    "pdfplumber>=0.11.4",
+    "PyPDF2>=3.0.1",
+    "PyStemmer>=2.2.0.3",
+    "python-dotenv>=1.0.1",
+    "pymupdf4llm>=0.0.17",
+    "safetensors>=0.4.5",
+    "semchunk>=2.2.0",
+    "tiktoken>=0.8.0",
+    "weave>=0.51.18",
+    "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
+]
+dev = [
+    "pytest>=8.3.3",
+    "isort>=5.13.2",
+    "black>=24.10.0",
+    "ruff>=0.6.9",
+]
+docs = [
+    "mkdocs>=1.6.1",
+    "mkdocstrings>=0.26.1",
+    "mkdocstrings-python>=1.11.1",
+    "mkdocs-material>=9.5.39",
+    "mkdocs-minify-plugin>=0.8.0",
+    "mkdocs-glightbox>=0.4.0",
+    "mkdocs-jupyter>=0.25.0",
+    "jupyter>=1.1.1",
+]
+
+[project.scripts]
+medrag = "medrag_multi_modal.cli:main"
+
+[tool.pytest.ini_options]
+pythonpath = "."
+testpaths = ["tests"]
+filterwarnings = "ignore::DeprecationWarning"
+
+[tool.setuptools]
+py-modules = ["medrag_multi_modal"]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d8aa85b060cff542af5b7e00093598f992ee01fd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,38 @@
+bm25s[full]>=0.2.2
+datasets>=3.1.0
+einops>=0.8.0
+firerequests>=0.0.7
+pdf2image>=1.17.0
+python-dotenv>=1.0.1
+pymupdf4llm>=0.0.17
+weave>=0.51.14
+pip>=24.2
+uv>=0.4.20
+pytest>=8.3.3
+PyPDF2>=3.0.1
+PyStemmer>=2.2.0.3
+safetensors>=0.4.5
+isort>=5.13.2
+black>=24.10.0
+ruff>=0.6.9
+marker-pdf>=0.2.17
+mkdocs>=1.6.1
+mkdocstrings>=0.26.1
+mkdocstrings-python>=1.11.1
+mkdocs-material>=9.5.39
+mkdocs-minify-plugin>=0.8.0
+mkdocs-glightbox>=0.4.0
+mkdocs-jupyter>=0.25.0
+jupyter>=1.1.1
+pdfplumber>=0.11.4
+semchunk>=2.2.0
+tiktoken>=0.8.0
+sentence-transformers>=3.2.0
+google-generativeai>=0.8.3
+mistralai>=1.1.0
+instructor>=1.6.3
+jsonlines>=4.0.0
+opencv-python>=4.10.0.84
+openai>=1.52.2
+streamlit>=1.39.0
+torch --index-url https://download.pytorch.org/whl/cpu
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e29c897e3a84397f83716b12e60cc33e019e523
--- /dev/null
+++ b/test.py
@@ -0,0 +1,12 @@
+import weave
+from datasets import load_dataset
+
+weave.init("ml-colabs/medrag-multi-modal")
+rows = load_dataset("cais/mmlu", "anatomy", split="test").to_list()
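+# Reshape each MMLU anatomy row into query/options/answer fields before publishing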
+for idx, row in enumerate(rows):
+    rows[idx] = {
+        "query": row["question"],
+        "options": row["choices"],
+        "answer": row["answer"],
+    }
+weave.publish(weave.Dataset(rows=rows, name="mmlu-anatomy-test"))