Commit dec332b
Parent(s): 3557a96

added ux, vision_api, qna.txt
- .gitignore            +5 -2
- archive/init_setup.py +27 -0
- raw_documents/qna.txt +3 -0
- streamlit_app.py      +65 -42
- ux/add_logo.py        +50 -0
- ux/apps.py            +31 -0
- ux/components.py      +279 -0
- ux/styles.py          +143 -0
- ux/trulens_logo.svg   +44 -0
- vision_api.py         +38 -0
.gitignore CHANGED
@@ -4,6 +4,9 @@
 results/
 
 *.sqlite
-ux/
 data/
-
+
+notebooks/test_model
+screenshot_questions/
+
+# ux/
archive/init_setup.py ADDED
@@ -0,0 +1,27 @@
+import main
+
+import pkg_resources
+import shutil
+import os
+
+### To trigger trulens evaluation
+main.main()
+
+### Finally, start streamlit app
+leaderboard_path = pkg_resources.resource_filename(
+    "trulens_eval", "Leaderboard.py"
+)
+evaluation_path = pkg_resources.resource_filename(
+    "trulens_eval", "pages/Evaluations.py"
+)
+ux_path = pkg_resources.resource_filename(
+    "trulens_eval", "ux"
+)
+
+os.makedirs("./pages", exist_ok=True)
+shutil.copyfile(leaderboard_path, os.path.join("./pages", "1_Leaderboard.py"))
+shutil.copyfile(evaluation_path, os.path.join("./pages", "2_Evaluations.py"))
+
+if os.path.exists("./ux"):
+    shutil.rmtree("./ux")
+shutil.copytree(ux_path, "./ux")
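If it helps to sanity-check the archived bootstrap script, here is a small sketch (not part of the commit) that only verifies the copies the script above performs; the paths come straight from the script.

    import os

    # After running archive/init_setup.py, the trulens_eval dashboard pages and
    # its ux assets should sit next to the app; this just checks the copies landed.
    for expected in ("pages/1_Leaderboard.py", "pages/2_Evaluations.py", "ux"):
        print(expected, "->", "ok" if os.path.exists(expected) else "missing")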
raw_documents/qna.txt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8b44d78e6dec3a285124f0a449ff5bae699ab4ff98ae3826a33a8eb4f182334
+size 1804
streamlit_app.py CHANGED
@@ -3,13 +3,11 @@ from streamlit_feedback import streamlit_feedback
 
 import os
 import pandas as pd
-import
+import base64
+from io import BytesIO
+import nest_asyncio
 
-import openai
-
-# from openai import OpenAI
 from llama_index.llms import OpenAI
-
 from llama_index import SimpleDirectoryReader
 from llama_index import Document
 from llama_index import VectorStoreIndex
@@ -17,38 +15,17 @@ from llama_index import ServiceContext
 from llama_index.embeddings import HuggingFaceEmbedding
 from llama_index.memory import ChatMemoryBuffer
 
-import
-import shutil
-import main
-
-### To trigger trulens evaluation
-main.main()
-
-### Finally, start streamlit app
-leaderboard_path = pkg_resources.resource_filename(
-    "trulens_eval", "Leaderboard.py"
-)
-evaluation_path = pkg_resources.resource_filename(
-    "trulens_eval", "pages/Evaluations.py"
-)
-ux_path = pkg_resources.resource_filename(
-    "trulens_eval", "ux"
-)
+from vision_api import get_transcribed_text
 
-
-shutil.copyfile(leaderboard_path, os.path.join("./pages", "1_Leaderboard.py"))
-shutil.copyfile(evaluation_path, os.path.join("./pages", "2_Evaluations.py"))
-
-if os.path.exists("./ux"):
-    shutil.rmtree("./ux")
-shutil.copytree(ux_path, "./ux")
+nest_asyncio.apply()
 
 # App title
 st.set_page_config(page_title="💬 Open AI Chatbot")
 openai_api = os.getenv("OPENAI_API_KEY")
 
 # "./raw_documents/HI_Knowledge_Base.pdf"
-input_files = ["./raw_documents/HI Chapter Summary Version 1.3.pdf"
+input_files = ["./raw_documents/HI Chapter Summary Version 1.3.pdf",
+               "./raw_documents/qna.txt"]
 embedding_model = "BAAI/bge-small-en-v1.5"
 system_content = ("You are a helpful study assistant. "
                   "You do not respond as 'User' or pretend to be 'User'. "
@@ -104,25 +81,25 @@ with st.sidebar:
     st.markdown("📖 Reach out to SakiMilo to learn how to create this app!")
 
 if "init" not in st.session_state.keys():
-    st.session_state.init = {"
+    st.session_state.init = {"warm_started": "No"}
     st.session_state.feedback = False
 
 # Store LLM generated responses
 if "messages" not in st.session_state.keys():
     st.session_state.messages = [{"role": "assistant",
-                                  "content": "How may I assist you today?"
+                                  "content": "How may I assist you today?",
+                                  "type": "text"}]
 
 if "feedback_key" not in st.session_state:
    st.session_state.feedback_key = 0
 
-
-
-    with st.chat_message(message["role"]):
-        st.write(message["content"])
+if "release_file" not in st.session_state:
+    st.session_state.release_file = "false"
 
 def clear_chat_history():
     st.session_state.messages = [{"role": "assistant",
-                                  "content": "How may I assist you today?"
+                                  "content": "How may I assist you today?",
+                                  "type": "text"}]
     chat_engine = get_query_engine(input_files=input_files,
                                    llm_model=selected_model,
                                    temperature=temperature,
@@ -187,23 +164,66 @@ def handle_feedback(user_response):
     st.toast("✔️ Feedback received!")
     st.session_state.feedback = False
 
+def handle_image_upload():
+    st.session_state.release_file = "true"
+
 # Warm start
-if st.session_state.init["
+if st.session_state.init["warm_started"] == "No":
     clear_chat_history()
-    st.session_state.init["
+    st.session_state.init["warm_started"] = "Yes"
+
+# Image upload option
+with st.sidebar:
+    image_file = st.file_uploader("Upload your image here...",
+                                  type=["png", "jpeg", "jpg"],
+                                  on_change=handle_image_upload)
+
+if st.session_state.release_file == "true" and image_file:
+    with st.spinner("Uploading..."):
+        b64string = base64.b64encode(image_file.read()).decode('utf-8')
+        message = {
+            "role": "user",
+            "content": b64string,
+            "type": "image"}
+        st.session_state.messages.append(message)
+
+        transcribed_msg = get_transcribed_text(b64string)
+        message = {
+            "role": "admin",
+            "content": transcribed_msg,
+            "type": "text"}
+        st.session_state.messages.append(message)
+    st.session_state.release_file = "false"
+
+# Display or clear chat messages
+for message in st.session_state.messages:
+    if message["role"] == "admin":
+        continue
+    with st.chat_message(message["role"]):
+        if message["type"] == "text":
+            st.write(message["content"])
+        elif message["type"] == "image":
+            img_io = BytesIO(base64.b64decode(message["content"].encode("utf-8")))
+            st.image(img_io)
 
 # User-provided prompt
 if prompt := st.chat_input(disabled=not openai_api):
     client = OpenAI()
-    st.session_state.messages.append({"role": "user",
+    st.session_state.messages.append({"role": "user",
+                                      "content": prompt,
+                                      "type": "text"})
     with st.chat_message("user"):
         st.write(prompt)
 
+# Retrieve text prompt from image submission
+if prompt is None and \
+    st.session_state.messages[-1]["role"] == "admin":
+    prompt = st.session_state.messages[-1]["content"]
+
 # Generate a new response if last message is not from assistant
 if st.session_state.messages[-1]["role"] != "assistant":
     with st.chat_message("assistant"):
         with st.spinner("Thinking..."):
-            # response = generate_llm_response(client, prompt)
             response = generate_llm_response(prompt)
             placeholder = st.empty()
             full_response = ""
@@ -212,9 +232,12 @@ if st.session_state.messages[-1]["role"] != "assistant":
                 placeholder.markdown(full_response)
             placeholder.markdown(full_response)
 
-    message = {"role": "assistant",
+    message = {"role": "assistant",
+               "content": full_response,
+               "type": "text"}
     st.session_state.messages.append(message)
 
+# Trigger feedback
 if st.session_state.feedback:
     result = streamlit_feedback(
         feedback_type="thumbs",
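The diff above switches every chat entry to a three-field dict (role, content, type): an uploaded image is stored as a base64 "image" entry under the "user" role, its GPT-4 Vision transcription is kept under a hidden "admin" role, and the display loop skips "admin" entries while the transcription is later promoted to the prompt. A minimal sketch of that flow outside Streamlit follows; the function names and the transcribe callable are illustrative stand-ins, not code from this commit.

    import base64

    def append_image_and_transcription(messages, image_bytes, transcribe):
        # The raw image goes into the history as a "user" entry of type "image".
        b64string = base64.b64encode(image_bytes).decode("utf-8")
        messages.append({"role": "user", "content": b64string, "type": "image"})
        # The transcription is a hidden "admin" entry; the display loop skips it.
        messages.append({"role": "admin",
                         "content": transcribe(b64string),
                         "type": "text"})
        return messages

    def next_prompt(messages, typed_prompt=None):
        # Mirrors the app logic: a typed prompt wins, otherwise the latest
        # admin transcription becomes the prompt fed to the LLM.
        if typed_prompt is None and messages and messages[-1]["role"] == "admin":
            return messages[-1]["content"]
        return typed_prompt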
ux/add_logo.py ADDED
@@ -0,0 +1,50 @@
+import base64
+
+import pkg_resources
+import streamlit as st
+
+from trulens_eval import __package__
+from trulens_eval import __version__
+
+
+def add_logo_and_style_overrides():
+    logo = open(
+        pkg_resources.resource_filename('trulens_eval', 'ux/trulens_logo.svg'),
+        "rb"
+    ).read()
+
+    logo_encoded = base64.b64encode(logo).decode()
+    st.markdown(
+        f"""
+        <style>
+            [data-testid="stSidebarNav"] {{
+                background-image: url('data:image/svg+xml;base64,{logo_encoded}');
+                background-repeat: no-repeat;
+                background-size: 300px auto;
+                padding-top: 50px;
+                background-position: 20px 20px;
+            }}
+            [data-testid="stSidebarNav"]::before {{
+                margin-left: 20px;
+                margin-top: 20px;
+                font-size: 30px;
+                position: relative;
+                top: 100px;
+            }}
+            [data-testid="stSidebarNav"]::after {{
+                margin-left: 20px;
+                color: #aaaaaa;
+                content: "{__package__} {__version__}";
+                font-size: 10pt;
+            }}
+
+            /* For list items in st.dataframe */
+            #portal .clip-region .boe-bubble {{
+                height: auto;
+                border-radius: 4px;
+                padding: 8px;
+            }}
+        </style>
+        """,
+        unsafe_allow_html=True,
+    )
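For context, a sketch of how a helper like this is typically invoked near the top of a Streamlit page so the injected CSS applies to the sidebar navigation. The page title, layout, and the local import path ux.add_logo are assumptions for illustration, not part of this commit.

    # Hypothetical page script; assumes trulens_eval is installed and the ux/
    # folder copied by archive/init_setup.py is importable as a local package.
    import streamlit as st

    from ux.add_logo import add_logo_and_style_overrides

    st.set_page_config(page_title="Leaderboard", layout="wide")
    add_logo_and_style_overrides()  # injects the TruLens logo and CSS overrides

    st.title("Leaderboard")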
ux/apps.py ADDED
@@ -0,0 +1,31 @@
+# Code in support of the Apps.py page.
+
+from typing import Any, ClassVar, Optional
+
+import pydantic
+
+from trulens_eval.app import App
+from trulens_eval.utils.serial import JSON
+
+
+class ChatRecord(pydantic.BaseModel):
+
+    model_config: ClassVar[dict] = dict(
+        arbitrary_types_allowed = True
+    )
+
+    # Human input
+    human: Optional[str] = None
+
+    # Computer response
+    computer: Optional[str] = None
+
+    # Jsonified record. Available only after the app is run on human input and
+    # produced a computer output.
+    record_json: Optional[JSON] = None
+
+    # The final app state for continuing the session.
+    app: App
+
+    # The state of the app as was when this record was produced.
+    app_json: JSON
ux/components.py ADDED
@@ -0,0 +1,279 @@
+import json
+import random
+from typing import Dict, List, Optional
+
+import pandas as pd
+import streamlit as st
+
+from trulens_eval.app import ComponentView
+from trulens_eval.keys import REDACTED_VALUE
+from trulens_eval.keys import should_redact_key
+from trulens_eval.schema import Metadata
+from trulens_eval.schema import Record
+from trulens_eval.schema import RecordAppCall
+from trulens_eval.schema import Select
+from trulens_eval.utils.containers import is_empty
+from trulens_eval.utils.json import jsonify
+from trulens_eval.utils.pyschema import CLASS_INFO
+from trulens_eval.utils.pyschema import is_noserio
+from trulens_eval.utils.serial import GetItemOrAttribute
+from trulens_eval.utils.serial import JSON_BASES
+from trulens_eval.utils.serial import Lens
+
+
+def write_or_json(st, obj):
+    """
+    Dispatch either st.json or st.write depending on content of `obj`. If it is
+    a string that can parses into strictly json (dict), use st.json, otherwise
+    use st.write.
+    """
+
+    if isinstance(obj, str):
+        try:
+            content = json.loads(obj)
+            if not isinstance(content, str):
+                st.json(content)
+            else:
+                st.write(content)
+
+        except BaseException:
+            st.write(obj)
+
+
+def copy_to_clipboard(path, *args, **kwargs):
+    st.session_state.clipboard = str(path)
+
+
+def draw_selector_button(path) -> None:
+    st.button(
+        key=str(random.random()),
+        label=f"{Select.render_for_dashboard(path)}",
+        on_click=copy_to_clipboard,
+        args=(path,)
+    )
+
+
+def render_selector_markdown(path) -> str:
+    return f"[`{Select.render_for_dashboard(path)}`]"
+
+
+def render_call_frame(frame: RecordAppCall, path=None) -> str:  # markdown
+    path = path or frame.path
+
+    return (
+        f"__{frame.method.name}__ (__{frame.method.obj.cls.module.module_name}.{frame.method.obj.cls.name}__)"
+    )
+
+
+def dict_to_md(dictionary: dict) -> str:
+    if len(dictionary) == 0:
+        return "No metadata."
+    mdheader = "|"
+    mdseparator = "|"
+    mdbody = "|"
+    for key, value in dictionary.items():
+        mdheader = mdheader + str(key) + "|"
+        mdseparator = mdseparator + "-------|"
+        mdbody = mdbody + str(value) + "|"
+    mdtext = mdheader + "\n" + mdseparator + "\n" + mdbody
+    return mdtext
+
+
+def draw_metadata(metadata: Metadata) -> str:
+    if isinstance(metadata, Dict):
+        return dict_to_md(metadata)
+    else:
+        return str(metadata)
+
+
+def draw_call(call: RecordAppCall) -> None:
+    top = call.stack[-1]
+
+    path = Select.for_record(
+        top.path._append(
+            step=GetItemOrAttribute(item_or_attribute=top.method.name)
+        )
+    )
+
+    with st.expander(label=f"Call " + render_call_frame(top, path=path) + " " +
+                     render_selector_markdown(path)):
+
+        args = call.args
+        rets = call.rets
+
+        for frame in call.stack[::-1][1:]:
+            st.write("Via " + render_call_frame(frame, path=path))
+
+        st.subheader(f"Inputs {render_selector_markdown(path.args)}")
+        if isinstance(args, Dict):
+            st.json(args)
+        else:
+            st.write(args)
+
+        st.subheader(f"Outputs {render_selector_markdown(path.rets)}")
+        if isinstance(rets, Dict):
+            st.json(rets)
+        else:
+            st.write(rets)
+
+
+def draw_calls(record: Record, index: int) -> None:
+    """
+    Draw the calls recorded in a `record`.
+    """
+
+    calls = record.calls
+
+    app_step = 0
+
+    for call in calls:
+        app_step += 1
+
+        if app_step != index:
+            continue
+
+        draw_call(call)
+
+
+def draw_prompt_info(query: Lens, component: ComponentView) -> None:
+    prompt_details_json = jsonify(component.json, skip_specials=True)
+
+    st.caption(f"Prompt details")
+
+    path = Select.for_app(query)
+
+    prompt_types = {
+        k: v for k, v in prompt_details_json.items() if (v is not None) and
+        not is_empty(v) and not is_noserio(v) and k != CLASS_INFO
+    }
+
+    for key, value in prompt_types.items():
+        with st.expander(key.capitalize() + " " +
+                         render_selector_markdown(getattr(path, key)),
+                         expanded=True):
+
+            if isinstance(value, (Dict, List)):
+                st.write(value)
+            else:
+                if isinstance(value, str) and len(value) > 32:
+                    st.text(value)
+                else:
+                    st.write(value)
+
+
+def draw_llm_info(query: Lens, component: ComponentView) -> None:
+    llm_details_json = component.json
+
+    st.subheader(f"*LLM Details*")
+    # path_str = str(query)
+    # st.text(path_str[:-4])
+
+    llm_kv = {
+        k: v for k, v in llm_details_json.items() if (v is not None) and
+        not is_empty(v) and not is_noserio(v) and k != CLASS_INFO
+    }
+    # CSS to inject contained in a string
+    hide_table_row_index = """
+        <style>
+        thead tr th:first-child {display:none}
+        tbody th {display:none}
+        </style>
+        """
+    df = pd.DataFrame.from_dict(llm_kv, orient='index').transpose()
+
+    # Redact any column whose name indicates it might be a secret.
+    for col in df.columns:
+        if should_redact_key(col):
+            df[col] = REDACTED_VALUE
+
+    # TODO: What about columns not indicating a secret but some values do
+    # indicate it as per `should_redact_value` ?
+
+    # Iterate over each column of the DataFrame
+    for column in df.columns:
+        path = getattr(Select.for_app(query), str(column))
+        # Check if any cell in the column is a dictionary
+
+        if any(isinstance(cell, dict) for cell in df[column]):
+            # Create new columns for each key in the dictionary
+            new_columns = df[column].apply(
+                lambda x: pd.Series(x) if isinstance(x, dict) else pd.Series()
+            )
+            new_columns.columns = [
+                f"{key} {render_selector_markdown(path)}"
+                for key in new_columns.columns
+            ]
+
+            # Remove extra zeros after the decimal point
+            new_columns = new_columns.applymap(
+                lambda x: '{0:g}'.format(x) if isinstance(x, float) else x
+            )
+
+            # Add the new columns to the original DataFrame
+            df = pd.concat([df.drop(column, axis=1), new_columns], axis=1)
+
+        else:
+            # TODO: add selectors to the output here
+
+            pass
+
+    # Inject CSS with Markdown
+
+    st.markdown(hide_table_row_index, unsafe_allow_html=True)
+    st.table(df)
+
+
+def draw_agent_info(query: Lens, component: ComponentView) -> None:
+    # copy of draw_prompt_info
+    # TODO: dedup
+    prompt_details_json = jsonify(component.json, skip_specials=True)
+
+    st.subheader(f"*Agent Details*")
+
+    path = Select.for_app(query)
+
+    prompt_types = {
+        k: v for k, v in prompt_details_json.items() if (v is not None) and
+        not is_empty(v) and not is_noserio(v) and k != CLASS_INFO
+    }
+
+    for key, value in prompt_types.items():
+        with st.expander(key.capitalize() + " " +
+                         render_selector_markdown(getattr(path, key)),
+                         expanded=True):
+
+            if isinstance(value, (Dict, List)):
+                st.write(value)
+            else:
+                if isinstance(value, str) and len(value) > 32:
+                    st.text(value)
+                else:
+                    st.write(value)
+
+
+def draw_tool_info(query: Lens, component: ComponentView) -> None:
+    # copy of draw_prompt_info
+    # TODO: dedup
+    prompt_details_json = jsonify(component.json, skip_specials=True)
+
+    st.subheader(f"*Tool Details*")
+
+    path = Select.for_app(query)
+
+    prompt_types = {
+        k: v for k, v in prompt_details_json.items() if (v is not None) and
+        not is_empty(v) and not is_noserio(v) and k != CLASS_INFO
+    }
+
+    for key, value in prompt_types.items():
+        with st.expander(key.capitalize() + " " +
+                         render_selector_markdown(getattr(path, key)),
+                         expanded=True):
+
+            if isinstance(value, (Dict, List)):
+                st.write(value)
+            else:
+                if isinstance(value, str) and len(value) > 32:
+                    st.text(value)
+                else:
+                    st.write(value)
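As a quick illustration of the markdown helper above, dict_to_md flattens a metadata dict into a one-row table. This sketch is not part of the commit, and the local import path ux.components is an assumption.

    # Assumes the copied ux/ folder is importable; dict_to_md needs only a dict.
    from ux.components import dict_to_md

    print(dict_to_md({"dataset": "qna.txt", "chunks": 3}))
    # |dataset|chunks|
    # |-------|-------|
    # |qna.txt|3|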
ux/styles.py ADDED
@@ -0,0 +1,143 @@
+from collections import defaultdict
+from enum import Enum
+import operator
+from typing import Callable, List, NamedTuple, Optional
+
+import numpy as np
+
+from trulens_eval.utils.serial import SerialModel
+
+
+class ResultCategoryType(Enum):
+    PASS = 0
+    WARNING = 1
+    FAIL = 2
+
+
+class CATEGORY:
+    """
+    Feedback result categories for displaying purposes: pass, warning, fail, or
+    unknown.
+    """
+
+    class Category(SerialModel):
+        name: str
+        adjective: str
+        threshold: float
+        color: str
+        icon: str
+        direction: Optional[str] = None
+        compare: Optional[Callable[[float, float], bool]] = None
+
+    class FeedbackDirection(NamedTuple):
+        name: str
+        ascending: bool
+        thresholds: List[float]
+
+    # support both directions by default
+    # TODO: make this configurable (per feedback definition & per app?)
+    directions = [
+        FeedbackDirection("HIGHER_IS_BETTER", False, [0, 0.6, 0.8]),
+        FeedbackDirection("LOWER_IS_BETTER", True, [0.2, 0.4, 1]),
+    ]
+
+    styling = {
+        "PASS": dict(color="#aaffaa", icon="✅"),
+        "WARNING": dict(color="#ffffaa", icon="⚠️"),
+        "FAIL": dict(color="#ffaaaa", icon="🛑"),
+    }
+
+    for category_name in ResultCategoryType._member_names_:
+        locals()[category_name] = defaultdict(dict)
+
+    for direction in directions:
+        a = sorted(
+            zip(["low", "medium", "high"], sorted(direction.thresholds)),
+            key=operator.itemgetter(1),
+            reverse=not direction.ascending,
+        )
+
+        for enum, (adjective, threshold) in enumerate(a):
+            category_name = ResultCategoryType(enum).name
+            locals()[category_name][direction.name] = Category(
+                name=category_name.lower(),
+                adjective=adjective,
+                threshold=threshold,
+                direction=direction.name,
+                compare=operator.ge
+                if direction.name == "HIGHER_IS_BETTER" else operator.le,
+                **styling[category_name],
+            )
+
+    UNKNOWN = Category(
+        name="unknown",
+        adjective="unknown",
+        threshold=np.nan,
+        color="#aaaaaa",
+        icon="?"
+    )
+
+    # order matters here because `of_score` returns the first best category
+    ALL = [PASS, WARNING, FAIL]  # not including UNKNOWN intentionally
+
+    @staticmethod
+    def of_score(score: float, higher_is_better: bool = True) -> Category:
+        direction_key = "HIGHER_IS_BETTER" if higher_is_better else "LOWER_IS_BETTER"
+
+        for cat in map(operator.itemgetter(direction_key), CATEGORY.ALL):
+            if cat.compare(score, cat.threshold):
+                return cat
+
+        return CATEGORY.UNKNOWN
+
+
+default_direction = "HIGHER_IS_BETTER"
+
+# These would be useful to include in our pages but don't yet see a way to do
+# this in streamlit.
+root_js = f"""
+    var default_pass_threshold = {CATEGORY.PASS[default_direction].threshold};
+    var default_warning_threshold = {CATEGORY.WARNING[default_direction].threshold};
+    var default_fail_threshold = {CATEGORY.FAIL[default_direction].threshold};
+"""
+
+# Not presently used. Need to figure out how to include this in streamlit pages.
+root_html = f"""
+<script>
+    {root_js}
+</script>
+"""
+
+stmetricdelta_hidearrow = """
+    <style> [data-testid="stMetricDelta"] svg { display: none; } </style>
+    """
+
+valid_directions = ["HIGHER_IS_BETTER", "LOWER_IS_BETTER"]
+
+cellstyle_jscode = {
+    k: f"""function(params) {{
+        let v = parseFloat(params.value);
+        """ + "\n".join(
+        f"""
+        if (v {'>=' if k == "HIGHER_IS_BETTER" else '<='} {cat.threshold}) {{
+            return {{
+                'color': 'black',
+                'backgroundColor': '{cat.color}'
+            }};
+        }}
+        """ for cat in map(operator.itemgetter(k), CATEGORY.ALL)
+    ) + f"""
+        // i.e. not a number
+        return {{
+            'color': 'black',
+            'backgroundColor': '{CATEGORY.UNKNOWN.color}'
+        }};
+    }}""" for k in valid_directions
+}
+
+hide_table_row_index = """
+    <style>
+    thead tr th:first-child {display:none}
+    tbody th {display:none}
+    </style>
+    """
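A small usage sketch for the category lookup defined above: of_score walks PASS, WARNING, FAIL in order and returns the first whose threshold comparison passes for the chosen direction. The local import path ux.styles is an assumption; the rest follows the thresholds in the file (0.8 and above is pass when higher is better, 0.2 and below is pass when lower is better).

    from ux.styles import CATEGORY

    # Higher-is-better scale (the default): 0.85 clears the 0.8 pass threshold.
    cat = CATEGORY.of_score(0.85)
    print(cat.name, cat.icon)  # expected: pass ✅

    # Lower-is-better scale: 0.1 is at or below the 0.2 pass threshold.
    low = CATEGORY.of_score(0.1, higher_is_better=False)
    print(low.name, low.icon)  # expected: pass ✅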
ux/trulens_logo.svg ADDED
vision_api.py ADDED
@@ -0,0 +1,38 @@
+import streamlit as st
+import os, base64, requests
+
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+def get_transcribed_text(base64_image):
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {OPENAI_API_KEY}"
+    }
+
+    payload = {
+        "model": "gpt-4-vision-preview",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "transcribe the image into text for me."
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 300
+    }
+
+    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+    transcribed_msg = response.json()["choices"][0]["message"]["content"]
+
+    return transcribed_msg
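The helper expects a base64-encoded image string and OPENAI_API_KEY set in the environment. A minimal sketch of calling it on a local file; the sample path is an assumption, any PNG or JPEG works since the payload wraps it as a data URL.

    import base64

    from vision_api import get_transcribed_text

    # Hypothetical local screenshot; requires OPENAI_API_KEY in the environment.
    with open("screenshot_questions/sample.jpg", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode("utf-8")

    print(get_transcribed_text(b64_image))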