update libraries and code to support Modal as OpenAI-based server

Files changed:
- document_qa/document_qa_engine.py (+23 -21)
- document_qa/langchain.py (+0 -1)
- requirements.txt (+22 -21)
- streamlit_app.py (+40 -118)
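The thrust of the commit: the app now talks to a single OpenAI-compatible endpoint served from Modal instead of branching between OpenAI and Hugging Face backends. A minimal sketch of the pattern, assuming `MODAL_1_URL` points at a Modal deployment that speaks the OpenAI chat-completions protocol and `API_KEY` holds whatever token that deployment expects (both variable names come from the diff; the URL shape, values, and probe prompt are placeholders):

```python
import os

from langchain_openai import ChatOpenAI

# Any server implementing the OpenAI chat-completions API works as base_url;
# a Modal app typically exposes it at https://<workspace>--<app>.modal.run/v1.
chat = ChatOpenAI(
    model="microsoft/Phi-4-mini-instruct",  # must match the model the endpoint serves
    temperature=0.0,
    base_url=os.environ["MODAL_1_URL"],
    api_key=os.environ.get("API_KEY"),
)

print(chat.invoke("Reply with one word: ready?").content)
```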
document_qa/document_qa_engine.py (CHANGED)

```diff
@@ -5,7 +5,8 @@ from typing import Union, Any, List
 
 import tiktoken
 from langchain.chains import create_extraction_chain
-from langchain.chains.…
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.chains.question_answering import stuff_prompt, refine_prompts, map_reduce_prompt, \
     map_rerank_prompt
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
@@ -14,7 +15,6 @@ from langchain_community.vectorstores.chroma import Chroma
 from langchain_core.vectorstores import VectorStore
 from tqdm import tqdm
 
-# from document_qa.embedding_visualiser import QueryVisualiser
 from document_qa.grobid_processors import GrobidProcessor
 from document_qa.langchain import ChromaAdvancedRetrieval
 
@@ -177,17 +177,19 @@ class DataStorage:
 
     def embed_document(self, doc_id, texts, metadatas):
         if doc_id not in self.embeddings_dict.keys():
-            self.embeddings_dict[doc_id] = self.engine.from_texts(…
+            self.embeddings_dict[doc_id] = self.engine.from_texts(
+                texts,
+                embedding=self.embedding_function,
+                metadatas=metadatas,
+                collection_name=doc_id)
         else:
             # Workaround Chroma (?) breaking change
             self.embeddings_dict[doc_id].delete_collection()
-            self.embeddings_dict[doc_id] = self.engine.from_texts(…
+            self.embeddings_dict[doc_id] = self.engine.from_texts(
+                texts,
+                embedding=self.embedding_function,
+                metadatas=metadatas,
+                collection_name=doc_id)
 
         self.embeddings_root_path = None
 
@@ -206,14 +208,13 @@ class DocumentQAEngine:
     def __init__(self,
                  llm,
                  data_storage: DataStorage,
-                 qa_chain_type="stuff",
                  grobid_url=None,
                  memory=None
                  ):
 
         self.llm = llm
         self.memory = memory
-        self.chain = …
+        self.chain = create_stuff_documents_chain(llm, self.default_prompts['stuff'].PROMPT)
         self.text_merger = TextMerger()
         self.data_storage = data_storage
 
@@ -271,7 +272,10 @@ class DocumentQAEngine:
         Returns both the context and the embedding information from a given query
         """
         db = self.data_storage.embeddings_dict[doc_id]
-        retriever = db.as_retriever(…
+        retriever = db.as_retriever(
+            search_kwargs={"k": context_size},
+            search_type="similarity_with_embeddings"
+        )
         relevant_documents = retriever.invoke(query)
 
         return relevant_documents
@@ -327,20 +331,18 @@
 
     def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
         relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size)
-        response = self.chain.…
-            question=query)
-
-        if self.memory:
-            self.memory.save_context({"input": query}, {"output": response})
+        response = self.chain.invoke({"context": relevant_documents, "question": query})
         return response, relevant_document_coordinates
 
     def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
         db = self.data_storage.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
         relevant_documents = retriever.invoke(query)
-        relevant_document_coordinates = […
+        relevant_document_coordinates = [
+            doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
+            for doc in
+            relevant_documents
+        ]
         if self.memory and len(self.memory.buffer_as_messages) > 0:
             relevant_documents.append(
                 Document(
```
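On the chain rewiring above: `create_stuff_documents_chain` builds a runnable that concatenates the retrieved documents into the prompt's `context` placeholder, which is why `_run_query` now calls `chain.invoke({"context": ..., "question": ...})` in place of the legacy chain call (truncated in the diff). A minimal standalone sketch, assuming a prompt with `context` and `question` variables; the model name and prompt text are illustrative:

```python
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder; any chat model works

# "stuff" strategy: all documents are stuffed into the single {context} slot.
prompt = ChatPromptTemplate.from_template(
    "Answer using only this context:\n\n{context}\n\nQuestion: {question}"
)
chain = create_stuff_documents_chain(llm, prompt)

docs = [Document(page_content="GROBID converts scholarly PDFs into TEI XML.")]
answer = chain.invoke({"context": docs, "question": "What does GROBID produce?"})
print(answer)  # the chain parses the model output to a plain string
```

Note also that the removed `self.memory.save_context(...)` call is not re-added, so `_run_query` no longer writes responses back into conversational memory; memory now feeds in only through `_get_context`.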
document_qa/langchain.py (CHANGED)

```diff
@@ -1,4 +1,3 @@
-from pathlib import Path
 from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection
 
 from langchain.schema import Document
```
requirements.txt (CHANGED)

```diff
@@ -1,32 +1,33 @@
 # Grobid
 grobid-quantities-client==0.4.0
 grobid-client-python==0.0.9
-…
+grobid-tei-xml==0.1.3
 
 # Utils
-tqdm==4.66.…
+tqdm==4.66.3
 pyyaml==6.0.1
 pytest==8.1.1
-streamlit==1.…
-lxml
-…
-python-dotenv
-watchdog
-dateparser
+streamlit==1.45.1
+lxml==5.2.1
+beautifulsoup4==4.12.3
+python-dotenv==1.0.1
+watchdog==4.0.0
+dateparser==1.2.0
+requests>=2.31.0
 
 # LLM
 chromadb==0.4.24
-tiktoken==0.…
-openai==1.…
-langchain==0.…
-langchain-core==0.…
-langchain-openai==0.…
-langchain-huggingface==0.0…
-langchain-community==0.…
+tiktoken==0.9.0
+openai==1.82.0
+langchain==0.3.25
+langchain-core==0.3.61
+langchain-openai==0.3.18
+langchain-huggingface==0.2.0
+langchain-community==0.3.21
 typing-inspect==0.9.0
-typing_extensions==4.…
-pydantic==2.6
-…
-streamlit-pdf-viewer==0.0.…
-umap-learn
-plotly
+typing_extensions==4.12.2
+pydantic==2.10.6
+sentence-transformers==2.6.1
+streamlit-pdf-viewer==0.0.22rc0
+umap-learn==0.5.6
+plotly==5.20.0
```
streamlit_app.py (CHANGED)

```diff
@@ -5,11 +5,9 @@ from tempfile import NamedTemporaryFile
 
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
-from langchain.memory import …
-from langchain_community.chat_models import ChatOpenAI
-from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain.memory import ConversationBufferMemory
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_openai import …
+from langchain_openai import ChatOpenAI
 from streamlit_pdf_viewer import pdf_viewer
 
 from document_qa.ner_client_generic import NERClientGeneric
@@ -20,30 +18,14 @@ import streamlit as st
 from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
 from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 
-…
-    "gpt-4-1106-preview"]
-
-OPENAI_EMBEDDINGS = [
-    'text-embedding-ada-002',
-    'text-embedding-3-large',
-    'openai-text-embedding-3-small'
-]
-
-OPEN_MODELS = {
-    'Mistral-Nemo-Instruct-2407': 'mistralai/Mistral-Nemo-Instruct-2407',
-    'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3',
-    'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
+API_MODELS = {
+    "microsoft/Phi-4-mini-instruct": os.environ["MODAL_1_URL"]
 }
 
-…
-    'SFR-Embedding-…
-    'SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R',
-    'NV-Embed': 'nvidia/NV-Embed-v1',
-    'e5-mistral-7b-instruct': 'intfloat/e5-mistral-7b-instruct',
-    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5'
+API_EMBEDDINGS = {
+    'intfloat/e5-large-v2': 'intfloat/e5-large-v2',
+    'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct',
+    'Salesforce/SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R'
 }
 
 if 'rqa' not in st.session_state:
@@ -141,48 +123,20 @@ def clear_memory():
 
 
 # @st.cache_resource
-def init_qa(…
-…
-        frequency_penalty=0.1)
-    if embeddings_name not in OPENAI_EMBEDDINGS:
-        st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
-        st.stop()
-        return
-    embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key)
-…
-        temperature=0,
-        frequency_penalty=0.1)
-    embeddings = OpenAIEmbeddings(model=embeddings_name)
-
-    elif model in OPEN_MODELS:
-        if embeddings_name is None:
-            embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME
-
-        chat = HuggingFaceEndpoint(
-            repo_id=OPEN_MODELS[model],
-            temperature=0.01,
-            max_new_tokens=4092,
-            model_kwargs={"max_length": 8192},
-            # callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])]
-        )
-        embeddings = HuggingFaceEmbeddings(
-            model_name=OPEN_EMBEDDINGS[embeddings_name])
-        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
-    else:
-        st.error("The model was not loaded properly. Try reloading. ")
-        st.stop()
-        return
+def init_qa(model_name, embeddings_name):
+    st.session_state['memory'] = ConversationBufferMemory(
+        memory_key="chat_history",
+        return_messages=True
+    )
+    chat = ChatOpenAI(
+        model=model_name,
+        temperature=0.0,
+        base_url=API_MODELS[model_name],
+        api_key=os.environ.get('API_KEY')
+    )
+
+    embeddings = HuggingFaceEmbeddings(
+        model_name=API_EMBEDDINGS[embeddings_name])
 
     storage = DataStorage(embeddings)
     return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
@@ -246,65 +200,31 @@ with st.sidebar:
     st.divider()
     st.session_state['model'] = model = st.selectbox(
         "Model:",
-        options=…
-        index=(…
+        options=API_MODELS.keys(),
+        index=(list(API_MODELS.keys())).index(
             os.environ["DEFAULT_MODEL"]) if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] else 0,
         placeholder="Select model",
         help="Select the LLM model:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
     )
-    embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS
 
     st.session_state['embeddings'] = embedding_name = st.selectbox(
         "Embeddings:",
-        options=…
-        index=…
+        options=API_EMBEDDINGS.keys(),
+        index=(list(API_EMBEDDINGS.keys())).index(
+            os.environ["DEFAULT_EMBEDDING"]) if "DEFAULT_EMBEDDING" in os.environ and os.environ[
+            "DEFAULT_EMBEDDING"] else 0,
         placeholder="Select embedding",
         help="Select the Embedding function:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
    )
 
-…
-    if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
-        api_key = st.text_input('Huggingface API Key', type="password")
 
-…
-    if api_key:
-        # st.session_state['api_key'] = is_api_key_provided = True
-        if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
-            with st.spinner("Preparing environment"):
-                st.session_state['api_keys'][model] = api_key
-                # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
-                #     os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
-                st.session_state['rqa'][model] = init_qa(model, embedding_name)
-
-    elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
-        if 'OPENAI_API_KEY' not in os.environ:
-            api_key = st.text_input('OpenAI API Key', type="password")
-            st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
-        else:
-            api_key = os.environ['OPENAI_API_KEY']
-
-        if api_key:
-            if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
-                with st.spinner("Preparing environment"):
-                    st.session_state['api_keys'][model] = api_key
-                    if 'OPENAI_API_KEY' not in os.environ:
-                        st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key)
-                    else:
-                        st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
-    # else:
-    #     is_api_key_provided = st.session_state['api_key']
-
-    # st.button(
-    #     'Reset chat memory.',
-    #     key="reset-memory-button",
-    #     on_click=clear_memory,
-    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
-    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+    api_key = os.environ['API_KEY']
+
+    if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
+        with st.spinner("Preparing environment"):
+            st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
+            st.session_state['api_keys'][model] = api_key
 
 left_column, right_column = st.columns([5, 4])
 right_column = right_column.container(border=True)
@@ -390,15 +310,16 @@ if uploaded_file and not st.session_state.loaded_embeddings:
     st.stop()
 
 with left_column:
-    with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
+    with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'):
         binary = uploaded_file.getvalue()
         tmp_file = NamedTemporaryFile()
         tmp_file.write(bytearray(binary))
         st.session_state['binary'] = binary
 
-        st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(…
+        st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(
+            tmp_file.name,
+            chunk_size=chunk_size,
+            perc_overlap=0.1)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
@@ -477,7 +398,7 @@ with right_column:
                     annotation_doc]
 
             if not text_response:
-                st.error("Something went wrong. Contact …
+                st.error("Something went wrong. Contact info AT sciencialab.com to report the issue through GitHub.")
 
             if mode == "llm":
                 if st.session_state['ner_processing']:
@@ -503,5 +424,6 @@ with left_column:
         annotation_outline_size=2,
         annotations=st.session_state['annotations'] if st.session_state['annotations'] else [],
         render_text=True,
-        scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state[…
+        scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state[
+            'scroll_to_first_annotation']) else None
     )
```
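One detail worth noting in `init_qa` above: embeddings stay local. `HuggingFaceEmbeddings` runs the selected `API_EMBEDDINGS` model in-process via sentence-transformers (hence the new `sentence-transformers==2.6.1` pin in requirements.txt), so only chat completions travel to the Modal endpoint. A small sketch with one of the listed models; the query string is illustrative:

```python
from langchain_huggingface import HuggingFaceEmbeddings

# Downloads the model from the Hugging Face Hub on first use, then embeds locally.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")

# e5 models expect a "query: " prefix for retrieval queries.
vector = embeddings.embed_query("query: at what temperature does steel melt?")
print(len(vector))  # e5-large-v2 produces 1024-dimensional vectors
```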