Zwea Htet committed · Commit ab0abe2
1 Parent(s): e262fe5

updated llm configs and prompts

Files changed:
- models/llamaCustom.py        +3   -8
- models/llamaCustomV2.py     +229  -0
- models/llms.py               +2   -0
- pages/llama_custom_demo.py   +2   -0
- requirements.txt             +2   -1
models/llamaCustom.py
CHANGED
@@ -18,11 +18,7 @@ from assets.prompts import custom_prompts
 
 # llama index
 from llama_index.core import (
-    StorageContext,
-    SimpleDirectoryReader,
     VectorStoreIndex,
-    load_index_from_storage,
-    PromptHelper,
     PromptTemplate,
 )
 from llama_index.core.llms import (
@@ -47,12 +43,10 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-# TODO: use the following prompt to format the answer at the end of the context prompt
 ANSWER_FORMAT = """
-
+Provide the answer to the user question in the following format:
 [FORMAT]
-
-The answer to the user question.
+Your answer to the user question above.
 Reference:
 The list of references (such as page number, title, chapter, section) to the specific sections of the documents that support your answer.
 [END_FORMAT]
@@ -200,6 +194,7 @@ class LlamaCustom:
         # condense_prompt=CHAT_ENGINE_CONDENSE_PROMPT_TEMPLATE,
         # # verbose=True,
         # )
+
         response = query_engine.query(query_str)
         # response = chat_engine.chat(message=query_str, chat_history=chat_history)
 
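Note: this hunk only updates the ANSWER_FORMAT constant; where it is attached to the query engine is outside the diff. Below is a minimal sketch of one way to wire it in, assuming the index built elsewhere in llamaCustom.py and llama-index's standard text_qa_template argument; the names are illustrative, not part of the commit.

# Sketch only -- not part of this commit.
qa_template = PromptTemplate(
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Answer the query using the context above.\n"
    + ANSWER_FORMAT
    + "\nQuery: {query_str}\n"
    "Answer: "
)
query_engine = index.as_query_engine(text_qa_template=qa_template)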
models/llamaCustomV2.py
ADDED
@@ -0,0 +1,229 @@
import os
import time
from llama_index.core import VectorStoreIndex
from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from typing import Any, Dict, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.schema import NodeWithScore
from llama_index.core.memory import ChatMemoryBuffer


llm = OpenAI(
    model="gpt-3.5-turbo-0125",
    api_key=os.getenv("OPENAI_API_KEY"),
)

# First, we create an input component to capture the user query
input_component = InputComponent()

# Next, we use the LLM to rewrite a user query
rewrite = (
    "Please write a query to a semantic search engine using the current conversation.\n"
    "\n"
    "\n"
    "{chat_history_str}"
    "\n"
    "\n"
    "Latest message: {query_str}\n"
    'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)

# we will retrieve two times, so we need to pack the retrieved nodes into a single list
argpack_component = ArgPackComponent()

# then postprocess/rerank with Colbert
reranker = ColbertRerank(top_n=3)

DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
    "Please format your response in the following way:\n"
    "Your answer here.\n"
    "Reference:\n"
    "Your references here (e.g. page numbers, titles, etc.).\n"
)


class ResponseWithChatHistory(CustomQueryComponent):
    llm: OpenAI = Field(..., description="OpenAI LLM")
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to use for the LLM"
    )
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt to use for the LLM",
    )

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # NOTE: this is OPTIONAL but we show you where to do validation as an example
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        # NOTE: These are required inputs. If you have optional inputs please override
        # `optional_input_keys_dict`
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        return {"response"}

    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        formatted_context = self.context_prompt.format(
            node_context=node_context, query_str=query_str
        )
        user_message = ChatMessage(role="user", content=formatted_context)

        chat_history.append(user_message)

        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt)
            ] + chat_history

        return chat_history

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)

        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the component asynchronously."""
        # NOTE: Optional, but async LLM calls are easy to implement
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)

        response = await self.llm.achat(prepared_context)

        return {"response": response}


class LlamaCustomV2:
    response_component = ResponseWithChatHistory(
        llm=llm,
        system_prompt=(
            "You are a Q&A system. You will be provided with the previous chat history, "
            "as well as possibly relevant context, to assist in answering a user message."
        ),
    )

    def __init__(self, model_name: str, index: VectorStoreIndex):
        self.model_name = model_name
        self.index = index
        self.retriever = index.as_retriever()
        self.chat_mode = "condense_plus_context"
        self.memory = ChatMemoryBuffer.from_defaults()
        self.verbose = True
        self._build_pipeline()

    def _build_pipeline(self):
        self.pipeline = QueryPipeline(
            modules={
                "input": input_component,
                "rewrite_template": rewrite_template,
                "llm": llm,
                "rewrite_retriever": self.retriever,
                "query_retriever": self.retriever,
                "join": argpack_component,
                "reranker": reranker,
                "response_component": self.response_component,
            },
            verbose=self.verbose,
        )
        # run both retrievers -- once with the hallucinated query, once with the real query
        self.pipeline.add_link(
            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "rewrite_template",
            src_key="chat_history_str",
            dest_key="chat_history_str",
        )
        self.pipeline.add_link("rewrite_template", "llm")
        self.pipeline.add_link("llm", "rewrite_retriever")
        self.pipeline.add_link("input", "query_retriever", src_key="query_str")

        # each input to the argpack component needs a dest key -- it can be anything
        # then, the argpack component will pack all the inputs into a single list
        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")

        # reranker needs the packed nodes and the query string
        self.pipeline.add_link("join", "reranker", dest_key="nodes")
        self.pipeline.add_link(
            "input", "reranker", src_key="query_str", dest_key="query_str"
        )

        # synthesizer needs the reranked nodes and query str
        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
        self.pipeline.add_link(
            "input", "response_component", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "response_component",
            src_key="chat_history",
            dest_key="chat_history",
        )

    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
        chat_history = self.memory.get()
        chat_history_str = "\n".join([str(x) for x in chat_history])

        response = self.pipeline.run(
            query_str=query_str,
            chat_history=chat_history,
            chat_history_str=chat_history_str,
        )

        user_msg = ChatMessage(role="user", content=query_str)
        print("user_msg: ", str(user_msg))
        print("response: ", str(response.message))
        self.memory.put(user_msg)
        self.memory.put(response.message)

        return str(response.message)

    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
        response = self.get_response(query_str=query_str, chat_history=chat_history)
        for word in response.split():
            yield word + " "
            time.sleep(0.05)
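llamaCustomV2.py chains a query-rewriting step, two retrieval passes (rewritten query and raw query), ColBERT reranking, and a chat-history-aware response component into a single QueryPipeline. A hedged usage sketch follows; the data folder, API-key handling, and index construction are assumptions for illustration, not part of the commit.

import os

# The module creates its OpenAI client at import time, so set the key first.
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from models.llamaCustomV2 import LlamaCustomV2

documents = SimpleDirectoryReader("data").load_data()  # hypothetical document folder
index = VectorStoreIndex.from_documents(documents)

bot = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)
for token in bot.get_stream_response("What does chapter 2 cover?", chat_history=[]):
    print(token, end="", flush=True)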
models/llms.py
CHANGED
@@ -35,6 +35,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
         llm_gpt_3_5_turbo_0125 = OpenAI(
             model=model_name,
             api_key=st.session_state.openai_api_key,
+            temperature=0.0,
         )
 
         return llm_gpt_3_5_turbo_0125
@@ -45,6 +46,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
             is_chat_model=True,
             additional_kwargs={"max_new_tokens": 250},
             prompt_key=st.session_state.replicate_api_token,
+            temperature=0.0,
         )
 
         return llm_llama_13b_v2_replicate
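Both loader branches now pin temperature=0.0, trading response variety for repeatable answers across reruns. For reference, the OpenAI branch is equivalent to constructing the client directly; the key below is a placeholder, since the app actually reads it from st.session_state.

from llama_index.llms.openai import OpenAI

# Same settings as the diff above; temperature=0.0 keeps answers close to deterministic.
llm = OpenAI(
    model="gpt-3.5-turbo-0125",
    api_key="sk-...",  # placeholder
    temperature=0.0,
)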
pages/llama_custom_demo.py
CHANGED
@@ -7,6 +7,7 @@ from typing import List
 from models.llms import load_llm, integrated_llms
 from models.embeddings import hf_embed_model, openai_embed_model
 from models.llamaCustom import LlamaCustom
+from models.llamaCustomV2 import LlamaCustomV2
 
 # from models.vector_database import pinecone_vector_store
 from utils.chatbox import show_previous_messages, show_chat_input
@@ -209,6 +210,7 @@ with tab1:
 
         st.write("Finishing Up ...")
         llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
+        # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
         st.session_state.llama_custom = llama_custom
 
         status.update(label="Ready to query!", state="complete", expanded=False)
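The demo page imports LlamaCustomV2 but leaves it commented out. Enabling it would be a one-line swap inside the "Finishing Up ..." step, sketched below under the assumption that V2's get_response/get_stream_response methods match what the chatbox utilities expect.

# Sketch of the swap -- not enabled in this commit.
llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
st.session_state.llama_custom = llama_custom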
requirements.txt
CHANGED
@@ -16,4 +16,5 @@ llama-index-vector-stores-pinecone
 pinecone-client>=3.0.0
 replicate>=0.25.1
 llama-index-llms-replicate
-sentence-transformers>=2.6.1
+sentence-transformers>=2.6.1
+llama-index-postprocessor-colbert-rerank
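The new llama-index-postprocessor-colbert-rerank package backs the ColbertRerank import in llamaCustomV2.py. A quick post-install smoke test (a sketch; constructing the reranker may download the ColBERT checkpoint on first use):

from llama_index.postprocessor.colbert_rerank import ColbertRerank

reranker = ColbertRerank(top_n=3)  # may download the ColBERT model on first run
print(type(reranker).__name__)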