Zwea Htet committed on
Commit ab0abe2 · 1 Parent(s): e262fe5

updated llm configs and prompts

models/llamaCustom.py CHANGED
@@ -18,11 +18,7 @@ from assets.prompts import custom_prompts
 
 # llama index
 from llama_index.core import (
-    StorageContext,
-    SimpleDirectoryReader,
     VectorStoreIndex,
-    load_index_from_storage,
-    PromptHelper,
     PromptTemplate,
 )
 from llama_index.core.llms import (
@@ -47,12 +43,10 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-# TODO: use the following prompt to format the answer at the end of the context prompt
 ANSWER_FORMAT = """
-Use the following example format for your answer:
+Provide the answer to the user question in the following format:
 [FORMAT]
-Answer:
-The answer to the user question.
+Your answer to the user question above.
 Reference:
 The list of references (such as page number, title, chapter, section) to the specific sections of the documents that support your answer.
 [END_FORMAT]
@@ -200,6 +194,7 @@ class LlamaCustom:
         # condense_prompt=CHAT_ENGINE_CONDENSE_PROMPT_TEMPLATE,
         # # verbose=True,
         # )
+
         response = query_engine.query(query_str)
         # response = chat_engine.chat(message=query_str, chat_history=chat_history)
 
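Note: the TODO removed above said this ANSWER_FORMAT block should go at the end of the context prompt. A minimal sketch of how it could be attached to a text-QA template (the template text, engine setup, and question below are illustrative assumptions, not part of this commit):

    from llama_index.core import PromptTemplate

    qa_template = PromptTemplate(
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Answer the question using only the context above.\n"
        "Question: {query_str}\n" + ANSWER_FORMAT
    )

    # query_engine = index.as_query_engine(text_qa_template=qa_template)
    # response = query_engine.query("What is covered in chapter 2?")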
models/llamaCustomV2.py ADDED
@@ -0,0 +1,229 @@
+import os
+import time
+from llama_index.core import VectorStoreIndex
+from llama_index.core.query_pipeline import (
+    QueryPipeline,
+    InputComponent,
+    ArgPackComponent,
+)
+from llama_index.core.prompts import PromptTemplate
+from llama_index.llms.openai import OpenAI
+from llama_index.postprocessor.colbert_rerank import ColbertRerank
+from typing import Any, Dict, List, Optional
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.llms import ChatMessage
+from llama_index.core.query_pipeline import CustomQueryComponent
+from llama_index.core.schema import NodeWithScore
+from llama_index.core.memory import ChatMemoryBuffer
+
+
+llm = OpenAI(
+    model="gpt-3.5-turbo-0125",
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
+
+# First, we create an input component to capture the user query
+input_component = InputComponent()
+
+# Next, we use the LLM to rewrite a user query
+rewrite = (
+    "Please write a query to a semantic search engine using the current conversation.\n"
+    "\n"
+    "\n"
+    "{chat_history_str}"
+    "\n"
+    "\n"
+    "Latest message: {query_str}\n"
+    'Query:"""\n'
+)
+rewrite_template = PromptTemplate(rewrite)
+
+# we will retrieve two times, so we need to pack the retrieved nodes into a single list
+argpack_component = ArgPackComponent()
+
+# then postprocess/rerank with Colbert
+reranker = ColbertRerank(top_n=3)
+
+DEFAULT_CONTEXT_PROMPT = (
+    "Here is some context that may be relevant:\n"
+    "-----\n"
+    "{node_context}\n"
+    "-----\n"
+    "Please write a response to the following question, using the above context:\n"
+    "{query_str}\n"
+    "Please format your response in the following way:\n"
+    "Your answer here.\n"
+    "Reference:\n"
+    " Your references here (e.g. page numbers, titles, etc.).\n"
+)
+
+
+class ResponseWithChatHistory(CustomQueryComponent):
+    llm: OpenAI = Field(..., description="OpenAI LLM")
+    system_prompt: Optional[str] = Field(
+        default=None, description="System prompt to use for the LLM"
+    )
+    context_prompt: str = Field(
+        default=DEFAULT_CONTEXT_PROMPT,
+        description="Context prompt to use for the LLM",
+    )
+
+    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate component inputs during run_component."""
+        # NOTE: this is OPTIONAL but we show you where to do validation as an example
+        return input
+
+    @property
+    def _input_keys(self) -> set:
+        """Input keys dict."""
+        # NOTE: These are required inputs. If you have optional inputs please override
+        # `optional_input_keys_dict`
+        return {"chat_history", "nodes", "query_str"}
+
+    @property
+    def _output_keys(self) -> set:
+        return {"response"}
+
+    def _prepare_context(
+        self,
+        chat_history: List[ChatMessage],
+        nodes: List[NodeWithScore],
+        query_str: str,
+    ) -> List[ChatMessage]:
+        node_context = ""
+        for idx, node in enumerate(nodes):
+            node_text = node.get_content(metadata_mode="llm")
+            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"
+
+        formatted_context = self.context_prompt.format(
+            node_context=node_context, query_str=query_str
+        )
+        user_message = ChatMessage(role="user", content=formatted_context)
+
+        chat_history.append(user_message)
+
+        if self.system_prompt is not None:
+            chat_history = [
+                ChatMessage(role="system", content=self.system_prompt)
+            ] + chat_history
+
+        return chat_history
+
+    def _run_component(self, **kwargs) -> Dict[str, Any]:
+        """Run the component."""
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = self.llm.chat(prepared_context)  # use the component's llm field, not the module-level llm
+
+        return {"response": response}
+
+    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
+        """Run the component asynchronously."""
+        # NOTE: Optional, but async LLM calls are easy to implement
+        chat_history = kwargs["chat_history"]
+        nodes = kwargs["nodes"]
+        query_str = kwargs["query_str"]
+
+        prepared_context = self._prepare_context(chat_history, nodes, query_str)
+
+        response = await self.llm.achat(prepared_context)
+
+        return {"response": response}
+
+
+class LlamaCustomV2:
+    response_component = ResponseWithChatHistory(
+        llm=llm,
+        system_prompt=(
+            "You are a Q&A system. You will be provided with the previous chat history, "
+            "as well as possibly relevant context, to assist in answering a user message."
+        ),
+    )
+
+    def __init__(self, model_name: str, index: VectorStoreIndex):
+        self.model_name = model_name
+        self.index = index
+        self.retriever = index.as_retriever()
+        self.chat_mode = "condense_plus_context"
+        self.memory = ChatMemoryBuffer.from_defaults()
+        self.verbose = True
+        self._build_pipeline()
+
+    def _build_pipeline(self):
+        self.pipeline = QueryPipeline(
+            modules={
+                "input": input_component,
+                "rewrite_template": rewrite_template,
+                "llm": llm,
+                "rewrite_retriever": self.retriever,
+                "query_retriever": self.retriever,
+                "join": argpack_component,
+                "reranker": reranker,
+                "response_component": self.response_component,
+            },
+            verbose=self.verbose,
+        )
+        # run both retrievers -- once with the hallucinated query, once with the real query
+        self.pipeline.add_link(
+            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "rewrite_template",
+            src_key="chat_history_str",
+            dest_key="chat_history_str",
+        )
+        self.pipeline.add_link("rewrite_template", "llm")
+        self.pipeline.add_link("llm", "rewrite_retriever")
+        self.pipeline.add_link("input", "query_retriever", src_key="query_str")
+
+        # each input to the argpack component needs a dest key -- it can be anything
+        # then, the argpack component will pack all the inputs into a single list
+        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
+        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")
+
+        # reranker needs the packed nodes and the query string
+        self.pipeline.add_link("join", "reranker", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "reranker", src_key="query_str", dest_key="query_str"
+        )
+
+        # synthesizer needs the reranked nodes and query str
+        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
+        self.pipeline.add_link(
+            "input", "response_component", src_key="query_str", dest_key="query_str"
+        )
+        self.pipeline.add_link(
+            "input",
+            "response_component",
+            src_key="chat_history",
+            dest_key="chat_history",
+        )
+
+    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
+        chat_history = self.memory.get()  # the chat_history argument is ignored; memory is kept internally
+        chat_history_str = "\n".join([str(x) for x in chat_history])
+
+        response = self.pipeline.run(
+            query_str=query_str,
+            chat_history=chat_history,
+            chat_history_str=chat_history_str,
+        )
+
+        user_msg = ChatMessage(role="user", content=query_str)
+        print("user_msg: ", str(user_msg))
+        print("response: ", str(response.message))
+        self.memory.put(user_msg)
+        self.memory.put(response.message)
+
+        return str(response.message)
+
+    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
+        response = self.get_response(query_str=query_str, chat_history=chat_history)
+        for word in response.split():  # simulate streaming word by word
+            yield word + " "
+            time.sleep(0.05)
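For reference, a minimal usage sketch of the new LlamaCustomV2 class, assuming an OPENAI_API_KEY in the environment (the data directory and questions below are illustrative only, not part of this commit):

    from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
    from models.llamaCustomV2 import LlamaCustomV2

    docs = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(docs)
    bot = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)

    # chat history is tracked internally via ChatMemoryBuffer
    answer = bot.get_response("What does the document say about pricing?", chat_history=[])

    # word-by-word streaming, e.g. for st.write_stream in the demo page
    for token in bot.get_stream_response("Can you summarize that?", chat_history=[]):
        print(token, end="")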
models/llms.py CHANGED
@@ -35,6 +35,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
         llm_gpt_3_5_turbo_0125 = OpenAI(
             model=model_name,
             api_key=st.session_state.openai_api_key,
+            temperature=0.0,
         )
 
         return llm_gpt_3_5_turbo_0125
@@ -45,6 +46,7 @@ def load_llm(model_name: str, source: str = "huggingface"):
             is_chat_model=True,
             additional_kwargs={"max_new_tokens": 250},
             prompt_key=st.session_state.replicate_api_token,
+            temperature=0.0,
         )
 
         return llm_llama_13b_v2_replicate
pages/llama_custom_demo.py CHANGED
@@ -7,6 +7,7 @@ from typing import List
 from models.llms import load_llm, integrated_llms
 from models.embeddings import hf_embed_model, openai_embed_model
 from models.llamaCustom import LlamaCustom
+from models.llamaCustomV2 import LlamaCustomV2
 
 # from models.vector_database import pinecone_vector_store
 from utils.chatbox import show_previous_messages, show_chat_input
@@ -209,6 +210,7 @@ with tab1:
 
         st.write("Finishing Up ...")
         llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
+        # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
         st.session_state.llama_custom = llama_custom
 
         status.update(label="Ready to query!", state="complete", expanded=False)
requirements.txt CHANGED
@@ -16,4 +16,5 @@ llama-index-vector-stores-pinecone
 pinecone-client>=3.0.0
 replicate>=0.25.1
 llama-index-llms-replicate
-sentence-transformers>=2.6.1
+sentence-transformers>=2.6.1
+llama-index-postprocessor-colbert-rerank