RegBot4.0 / models /langOpen.py
hbui's picture
Update models/langOpen.py
92a6b2a verified
import os
import openai
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain_pinecone import PineconeVectorStore
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.llms import Cohere
# Raw prompt text sent to the chat model. It has two placeholders:
#   {context} - the passage text the answer must be based on
#   {topic}   - the user's question
# The text instructs the model to answer only from the supplied context and
# to reply in an "Answer:" / "Reference:" layout.
prompt_template = """You are an expert on California Drinking Water Regulations.
Answer the question solely using relevant regulations in the given context. DO NOT USE ANY OTHER SOURCES.
If the given context does not contain the relevant information, say so.
Context: {context}
Topic: {topic}
Use the following example format for your answer:
Answer:
The answer to the user question.
Reference:
The list of references to the specific sections of the documents that support your answer.
"""
# Template object built from the text above; input_variables must list exactly
# the placeholder names used in prompt_template.
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "topic"])
class LangOpen:
    """Question answering over a Pinecone vector store with an OpenAI chat model.

    A query is answered by fetching candidate passages from the vector store,
    reranking them with ``CohereRerank``, and feeding the reranked passages to
    an ``LLMChain`` built on the module-level ``PROMPT`` template.
    """

    # Name of the Pinecone index opened by initialize_index().
    INDEX_NAME = "llamaparse-md3-openai-embeddings"
    # Model name passed to OpenAIEmbeddings.
    EMBEDDING_MODEL = "text-embedding-3-large"
    # Number of candidates requested from the vector store before reranking.
    RETRIEVAL_K = 10

    def __init__(self, model_name: str) -> None:
        """Create the vector store handle, the chat model and the LLM chain.

        Args:
            model_name: Chat model identifier passed to ``ChatOpenAI``.
        """
        self.index = self.initialize_index("langOpen")
        self.llm = ChatOpenAI(temperature=0.01, model=model_name)
        self.chain = LLMChain(llm=self.llm, prompt=PROMPT)

    def initialize_index(self, index_name):
        """Return the Pinecone vector store used for retrieval.

        Args:
            index_name: Kept for interface compatibility only. The value is
                ignored and the store is always opened on ``INDEX_NAME``,
                which is the index the code has always used regardless of
                the argument.

        Returns:
            A ``PineconeVectorStore`` bound to ``INDEX_NAME``.
        """
        embeddings = OpenAIEmbeddings(model=self.EMBEDDING_MODEL)
        return PineconeVectorStore(index_name=self.INDEX_NAME, embedding=embeddings)

    def get_response(self, query_str):
        """Answer ``query_str`` using the reranked passages as context.

        All reranked passages are merged into a single context and the chain
        is invoked exactly once. Running the chain once per passage and
        keeping only the first answer would waste every other LLM call and
        hide the remaining passages from the model.

        Args:
            query_str: The user's question; also used as the prompt topic.

        Returns:
            The text produced by the chain.
        """
        print("query_str: ", query_str)
        print("model_name: ", self.llm.model_name)
        docs = self._build_retriever().get_relevant_documents(query_str)
        # If nothing was retrieved the context is an empty string. The chain
        # still runs, and the prompt itself tells the model to say when the
        # context lacks the relevant information, so no IndexError is raised.
        inputs = {"context": self._build_context(docs), "topic": query_str}
        return self.chain.apply([inputs])[0]["text"]

    def _build_retriever(self):
        """Build a retriever that reranks vector-store hits with CohereRerank."""
        base_retriever = self.index.as_retriever(
            search_type="similarity", search_kwargs={"k": self.RETRIEVAL_K}
        )
        return ContextualCompressionRetriever(
            base_compressor=CohereRerank(), base_retriever=base_retriever
        )

    @staticmethod
    def _build_context(docs):
        """Join the ``page_content`` of each document, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)