Spaces:
Runtime error
Runtime error
Commit
·
039aebb
0
Parent(s):
Duplicate from deepset/wikipedia-assistant
Browse filesCo-authored-by: Danijel Petkovic <danijelpetkovic@users.noreply.huggingface.co>
- .gitattributes +27 -0
- .gitignore +154 -0
- .streamlit/config.toml +2 -0
- README.md +39 -0
- app.py +49 -0
- context_server/Dockerfile +23 -0
- context_server/__init__.py +0 -0
- context_server/main.py +122 -0
- context_server/requirements.txt +7 -0
- lfqa.png +0 -0
- lfqa_server/Dockerfile +19 -0
- lfqa_server/__init__.py +0 -0
- lfqa_server/main.py +130 -0
- lfqa_server/requirements.txt +7 -0
- multipage.py +65 -0
- pages/ask.py +379 -0
- pages/info.py +92 -0
- pages/settings.py +95 -0
- requirements-dev.txt +4 -0
- requirements.txt +11 -0
- style.css +178 -0
- training/run_retriever_no_trainer.py +381 -0
- training/run_retriever_no_trainer_gpl.py +403 -0
- training/run_seq2seq_no_trainer.py +446 -0
- util/common.py +120 -0
- util/create_dpr_training_from_dataset.py +103 -0
- util/create_dpr_training_from_faiss.py +144 -0
- util/create_faiss_index.py +67 -0
- util/eval_generate.py +140 -0
- util/kilt_create_dpr_support_docs.py +109 -0
- util/query_smoke_test.py +81 -0
- wikipedia_answer.png +0 -0
- wikipedia_context.png +0 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
105 |
+
__pypackages__/
|
106 |
+
|
107 |
+
# Celery stuff
|
108 |
+
celerybeat-schedule
|
109 |
+
celerybeat.pid
|
110 |
+
|
111 |
+
# SageMath parsed files
|
112 |
+
*.sage.py
|
113 |
+
|
114 |
+
# Environments
|
115 |
+
.env
|
116 |
+
.venv
|
117 |
+
env/
|
118 |
+
venv/
|
119 |
+
ENV/
|
120 |
+
env.bak/
|
121 |
+
venv.bak/
|
122 |
+
|
123 |
+
# Spyder project settings
|
124 |
+
.spyderproject
|
125 |
+
.spyproject
|
126 |
+
|
127 |
+
# Rope project settings
|
128 |
+
.ropeproject
|
129 |
+
|
130 |
+
# mkdocs documentation
|
131 |
+
/site
|
132 |
+
|
133 |
+
# mypy
|
134 |
+
.mypy_cache/
|
135 |
+
.dmypy.json
|
136 |
+
dmypy.json
|
137 |
+
|
138 |
+
# Pyre type checker
|
139 |
+
.pyre/
|
140 |
+
|
141 |
+
# pytype static type analyzer
|
142 |
+
.pytype/
|
143 |
+
|
144 |
+
# Cython debug symbols
|
145 |
+
cython_debug/
|
146 |
+
|
147 |
+
# PyCharm
|
148 |
+
# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
|
149 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
150 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
151 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
152 |
+
.idea/
|
153 |
+
out.flac
|
154 |
+
out.mp3
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[logger]
|
2 |
+
level = "debug"
|
README.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Wikipedia Assistant
|
3 |
+
emoji: 🌖
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.9.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: deepset/wikipedia-assistant
|
11 |
+
---
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
|
15 |
+
`title`: _string_
|
16 |
+
Display title for the Space
|
17 |
+
|
18 |
+
`emoji`: _string_
|
19 |
+
Space emoji (emoji-only character allowed)
|
20 |
+
|
21 |
+
`colorFrom`: _string_
|
22 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
23 |
+
|
24 |
+
`colorTo`: _string_
|
25 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
26 |
+
|
27 |
+
`sdk`: `streamlit`
|
28 |
+
Can be either `gradio` or `streamlit`
|
29 |
+
|
30 |
+
`sdk_version` : `1.9.0`
|
31 |
+
Only applicable for `streamlit` SDK.
|
32 |
+
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
33 |
+
|
34 |
+
`app_file`: _string_
|
35 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
36 |
+
Path is relative to the root of the repository.
|
37 |
+
|
38 |
+
`pinned`: _boolean_
|
39 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import logging
|
3 |
+
import streamlit as st
|
4 |
+
from multipage import MultiPage
|
5 |
+
from pages import ask, settings, info
|
6 |
+
|
7 |
+
|
8 |
+
logging.basicConfig(
|
9 |
+
level=logging.DEBUG,
|
10 |
+
format="%(levelname)s %(asctime)s %(name)s:%(message)s",
|
11 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
12 |
+
force=True,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def init_session_key_value(key, value):
|
17 |
+
if key not in st.session_state:
|
18 |
+
st.session_state[key] = value
|
19 |
+
|
20 |
+
|
21 |
+
lfqa_api = "HuggingFace" if "api_lfqa_selector" not in st.secrets else st.secrets["api_lfqa_selector"]
|
22 |
+
session_values = {"api_lfqa_selector": lfqa_api,
|
23 |
+
"tts": "Google",
|
24 |
+
"min_length": 64,
|
25 |
+
"max_length": 256,
|
26 |
+
"do_sample": False,
|
27 |
+
"early_stopping": True,
|
28 |
+
"num_beams": 8,
|
29 |
+
"temperature": 1.0,
|
30 |
+
"top_k": None,
|
31 |
+
"top_p": None,
|
32 |
+
"no_repeat_ngram_size": 3,
|
33 |
+
"num_return_sequences": 1}
|
34 |
+
|
35 |
+
for k, v in session_values.items():
|
36 |
+
init_session_key_value(k, v)
|
37 |
+
|
38 |
+
app = MultiPage()
|
39 |
+
st.set_page_config(
|
40 |
+
page_title="Wikipedia Assistant",
|
41 |
+
initial_sidebar_state="expanded",
|
42 |
+
)
|
43 |
+
# Add all your application here
|
44 |
+
app.add_page("Home", "house", ask.app)
|
45 |
+
app.add_page("Settings", "gear", settings.app)
|
46 |
+
app.add_page("Info", "info", info.app)
|
47 |
+
|
48 |
+
# The main app
|
49 |
+
app.run()
|
context_server/Dockerfile
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
|
2 |
+
#set up environment
|
3 |
+
RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
|
4 |
+
RUN apt-get install unzip
|
5 |
+
RUN apt-get -y install python3
|
6 |
+
RUN apt-get -y install python3-pip
|
7 |
+
|
8 |
+
WORKDIR /code
|
9 |
+
|
10 |
+
ENV HF_HOME=/code/cache
|
11 |
+
|
12 |
+
COPY ./requirements.txt /code/requirements.txt
|
13 |
+
|
14 |
+
RUN pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
|
15 |
+
RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
|
16 |
+
|
17 |
+
COPY ./main.py /code/app/main.py
|
18 |
+
|
19 |
+
COPY ./data/kilt_wiki_prepared/ /code/data/kilt_wiki_prepared
|
20 |
+
|
21 |
+
COPY ./data/kilt_wikipedia.faiss /code/data/kilt_wikipedia.faiss
|
22 |
+
|
23 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
|
context_server/__init__.py
ADDED
File without changes
|
context_server/main.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from fastapi import FastAPI, Depends, status
|
3 |
+
from fastapi.responses import PlainTextResponse
|
4 |
+
from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder
|
5 |
+
|
6 |
+
from datasets import load_from_disk
|
7 |
+
import time
|
8 |
+
from typing import Dict
|
9 |
+
|
10 |
+
import jwt
|
11 |
+
from decouple import config
|
12 |
+
from fastapi import Request, HTTPException
|
13 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
14 |
+
|
15 |
+
JWT_SECRET = config("secret")
|
16 |
+
JWT_ALGORITHM = config("algorithm")
|
17 |
+
|
18 |
+
app = FastAPI()
|
19 |
+
app.ready = False
|
20 |
+
columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
|
21 |
+
'wikidata_info', 'history']
|
22 |
+
|
23 |
+
min_snippet_length = 20
|
24 |
+
topk = 21
|
25 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
26 |
+
model = DPRQuestionEncoder.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki").to(device)
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki")
|
28 |
+
_ = model.eval()
|
29 |
+
|
30 |
+
index_file_name = "./data/kilt_wikipedia.faiss"
|
31 |
+
|
32 |
+
kilt_wikipedia_paragraphs = load_from_disk("./data/kilt_wiki_prepared")
|
33 |
+
# use paragraphs that are not simple fragments or very short sentences
|
34 |
+
kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 200)
|
35 |
+
|
36 |
+
|
37 |
+
class JWTBearer(HTTPBearer):
|
38 |
+
def __init__(self, auto_error: bool = True):
|
39 |
+
super(JWTBearer, self).__init__(auto_error=auto_error)
|
40 |
+
|
41 |
+
async def __call__(self, request: Request):
|
42 |
+
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
|
43 |
+
if credentials:
|
44 |
+
if not credentials.scheme == "Bearer":
|
45 |
+
raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
|
46 |
+
if not self.verify_jwt(credentials.credentials):
|
47 |
+
raise HTTPException(status_code=403, detail="Invalid token or expired token.")
|
48 |
+
return credentials.credentials
|
49 |
+
else:
|
50 |
+
raise HTTPException(status_code=403, detail="Invalid authorization code.")
|
51 |
+
|
52 |
+
def verify_jwt(self, jwtoken: str) -> bool:
|
53 |
+
isTokenValid: bool = False
|
54 |
+
|
55 |
+
try:
|
56 |
+
payload = decodeJWT(jwtoken)
|
57 |
+
except:
|
58 |
+
payload = None
|
59 |
+
if payload:
|
60 |
+
isTokenValid = True
|
61 |
+
return isTokenValid
|
62 |
+
|
63 |
+
|
64 |
+
def token_response(token: str):
|
65 |
+
return {
|
66 |
+
"access_token": token
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
def signJWT(user_id: str) -> Dict[str, str]:
|
71 |
+
payload = {
|
72 |
+
"user_id": user_id,
|
73 |
+
"expires": time.time() + 6000
|
74 |
+
}
|
75 |
+
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
76 |
+
|
77 |
+
return token_response(token)
|
78 |
+
|
79 |
+
|
80 |
+
def decodeJWT(token: str) -> dict:
|
81 |
+
try:
|
82 |
+
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
|
83 |
+
return decoded_token if decoded_token["expires"] >= time.time() else None
|
84 |
+
except:
|
85 |
+
return {}
|
86 |
+
|
87 |
+
|
88 |
+
def embed_questions_for_retrieval(questions):
|
89 |
+
query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
|
90 |
+
with torch.no_grad():
|
91 |
+
q_reps = model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output
|
92 |
+
return q_reps.cpu().numpy()
|
93 |
+
|
94 |
+
def query_index(question):
|
95 |
+
question_embedding = embed_questions_for_retrieval([question])
|
96 |
+
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
|
97 |
+
columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
|
98 |
+
'start_character', 'end_character']
|
99 |
+
retrieved_examples = []
|
100 |
+
r = list(zip(wiki_passages[k] for k in columns))
|
101 |
+
for i in range(topk):
|
102 |
+
retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
|
103 |
+
return retrieved_examples
|
104 |
+
|
105 |
+
|
106 |
+
@app.on_event("startup")
|
107 |
+
def startup():
|
108 |
+
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
|
109 |
+
app.ready = True
|
110 |
+
|
111 |
+
|
112 |
+
@app.get("/healthz")
|
113 |
+
def healthz():
|
114 |
+
if app.ready:
|
115 |
+
return PlainTextResponse("ok")
|
116 |
+
return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
|
117 |
+
|
118 |
+
|
119 |
+
@app.get("/find_context", dependencies=[Depends(JWTBearer())])
|
120 |
+
def find_context(question: str = None):
|
121 |
+
return [res for res in query_index(question) if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
|
122 |
+
|
context_server/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
transformers
|
3 |
+
fastapi
|
4 |
+
faiss-gpu
|
5 |
+
uvicorn[standard]
|
6 |
+
PyJWT==1.7.1
|
7 |
+
python-decouple==3.3
|
lfqa.png
ADDED
![]() |
lfqa_server/Dockerfile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
|
2 |
+
#set up environment
|
3 |
+
RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
|
4 |
+
RUN apt-get install unzip
|
5 |
+
RUN apt-get -y install python3
|
6 |
+
RUN apt-get -y install python3-pip
|
7 |
+
|
8 |
+
WORKDIR /code
|
9 |
+
|
10 |
+
ENV HF_HOME=/code/cache
|
11 |
+
|
12 |
+
COPY ./requirements.txt /code/requirements.txt
|
13 |
+
|
14 |
+
RUN pip3 install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
|
15 |
+
RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
|
16 |
+
|
17 |
+
COPY ./main.py /code/app/main.py
|
18 |
+
|
19 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
|
lfqa_server/__init__.py
ADDED
File without changes
|
lfqa_server/main.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from fastapi import FastAPI, Depends, status
|
3 |
+
from fastapi.responses import PlainTextResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
6 |
+
|
7 |
+
import time
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
|
10 |
+
import jwt
|
11 |
+
from decouple import config
|
12 |
+
from fastapi import Request, HTTPException
|
13 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
14 |
+
|
15 |
+
JWT_SECRET = config("secret")
|
16 |
+
JWT_ALGORITHM = config("algorithm")
|
17 |
+
|
18 |
+
app = FastAPI()
|
19 |
+
app.ready = False
|
20 |
+
|
21 |
+
device = ("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
|
23 |
+
model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
|
24 |
+
_ = model.eval()
|
25 |
+
|
26 |
+
|
27 |
+
class JWTBearer(HTTPBearer):
|
28 |
+
def __init__(self, auto_error: bool = True):
|
29 |
+
super(JWTBearer, self).__init__(auto_error=auto_error)
|
30 |
+
|
31 |
+
async def __call__(self, request: Request):
|
32 |
+
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
|
33 |
+
if credentials:
|
34 |
+
if not credentials.scheme == "Bearer":
|
35 |
+
raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
|
36 |
+
if not self.verify_jwt(credentials.credentials):
|
37 |
+
raise HTTPException(status_code=403, detail="Invalid token or expired token.")
|
38 |
+
return credentials.credentials
|
39 |
+
else:
|
40 |
+
raise HTTPException(status_code=403, detail="Invalid authorization code.")
|
41 |
+
|
42 |
+
def verify_jwt(self, jwtoken: str) -> bool:
|
43 |
+
isTokenValid: bool = False
|
44 |
+
|
45 |
+
try:
|
46 |
+
payload = decodeJWT(jwtoken)
|
47 |
+
except:
|
48 |
+
payload = None
|
49 |
+
if payload:
|
50 |
+
isTokenValid = True
|
51 |
+
return isTokenValid
|
52 |
+
|
53 |
+
|
54 |
+
def token_response(token: str):
|
55 |
+
return {
|
56 |
+
"access_token": token
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
def signJWT(user_id: str) -> Dict[str, str]:
|
61 |
+
payload = {
|
62 |
+
"user_id": user_id,
|
63 |
+
"expires": time.time() + 6000
|
64 |
+
}
|
65 |
+
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
66 |
+
|
67 |
+
return token_response(token)
|
68 |
+
|
69 |
+
|
70 |
+
def decodeJWT(token: str) -> dict:
|
71 |
+
try:
|
72 |
+
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
|
73 |
+
return decoded_token if decoded_token["expires"] >= time.time() else None
|
74 |
+
except:
|
75 |
+
return {}
|
76 |
+
|
77 |
+
|
78 |
+
class LFQAParameters(BaseModel):
|
79 |
+
min_length: int = 50
|
80 |
+
max_length: int = 250
|
81 |
+
do_sample: bool = False
|
82 |
+
early_stopping: bool = True
|
83 |
+
num_beams: int = 8
|
84 |
+
temperature: float = 1.0
|
85 |
+
top_k: float = None
|
86 |
+
top_p: float = None
|
87 |
+
no_repeat_ngram_size: int = 3
|
88 |
+
num_return_sequences: int = 1
|
89 |
+
|
90 |
+
|
91 |
+
class InferencePayload(BaseModel):
|
92 |
+
model_input: str
|
93 |
+
parameters: Optional[LFQAParameters] = LFQAParameters()
|
94 |
+
|
95 |
+
|
96 |
+
@app.on_event("startup")
|
97 |
+
def startup():
|
98 |
+
app.ready = True
|
99 |
+
|
100 |
+
|
101 |
+
@app.get("/healthz")
|
102 |
+
def healthz():
|
103 |
+
if app.ready:
|
104 |
+
return PlainTextResponse("ok")
|
105 |
+
return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
|
106 |
+
|
107 |
+
|
108 |
+
@app.post("/generate/", dependencies=[Depends(JWTBearer())])
|
109 |
+
def generate(context: InferencePayload):
|
110 |
+
|
111 |
+
model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
|
112 |
+
param = context.parameters
|
113 |
+
generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
|
114 |
+
attention_mask=model_input["attention_mask"].to(device),
|
115 |
+
min_length=param.min_length,
|
116 |
+
max_length=param.max_length,
|
117 |
+
do_sample=param.do_sample,
|
118 |
+
early_stopping=param.early_stopping,
|
119 |
+
num_beams=param.num_beams,
|
120 |
+
temperature=param.temperature,
|
121 |
+
top_k=param.top_k,
|
122 |
+
top_p=param.top_p,
|
123 |
+
no_repeat_ngram_size=param.no_repeat_ngram_size,
|
124 |
+
num_return_sequences=param.num_return_sequences)
|
125 |
+
answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
|
126 |
+
clean_up_tokenization_spaces=True)
|
127 |
+
results = []
|
128 |
+
for answer in answers:
|
129 |
+
results.append({"generated_text": answer})
|
130 |
+
return results
|
lfqa_server/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
transformers
|
3 |
+
fastapi
|
4 |
+
faiss-gpu
|
5 |
+
uvicorn[standard]
|
6 |
+
PyJWT==1.7.1
|
7 |
+
python-decouple==3.3
|
multipage.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is the framework for generating multiple Streamlit applications
|
3 |
+
through an object oriented framework.
|
4 |
+
"""
|
5 |
+
|
6 |
+
# Import necessary libraries
|
7 |
+
import streamlit as st
|
8 |
+
from streamlit_option_menu import option_menu
|
9 |
+
|
10 |
+
|
11 |
+
# Define the multipage class to manage the multiple apps in our program
|
12 |
+
class MultiPage:
|
13 |
+
"""Framework for combining multiple streamlit applications."""
|
14 |
+
|
15 |
+
def __init__(self) -> None:
|
16 |
+
"""Constructor class to generate a list which will store all our applications as an instance variable."""
|
17 |
+
self.pages = []
|
18 |
+
|
19 |
+
def add_page(self, title, icon, func) -> None:
|
20 |
+
"""Class Method to Add pages to the project
|
21 |
+
|
22 |
+
Args:
|
23 |
+
title ([str]): The title of page which we are adding to the list of apps
|
24 |
+
|
25 |
+
func: Python function to render this page in Streamlit
|
26 |
+
"""
|
27 |
+
|
28 |
+
self.pages.append(
|
29 |
+
{
|
30 |
+
"title": title,
|
31 |
+
"icon": icon,
|
32 |
+
"function": func
|
33 |
+
}
|
34 |
+
)
|
35 |
+
|
36 |
+
def run(self):
|
37 |
+
# Drodown to select the page to run
|
38 |
+
st.markdown("""
|
39 |
+
<style>
|
40 |
+
section[data-testid="stSidebar"] > div:first-of-type {
|
41 |
+
background-color: var(--secondary-background-color);
|
42 |
+
background: var(--secondary-background-color);
|
43 |
+
width: 250px;
|
44 |
+
padding: 4rem 0;
|
45 |
+
box-shadow: -2rem 0px 2rem 2rem rgba(0,0,0,0.16);
|
46 |
+
}
|
47 |
+
section[aria-expanded="true"] > div:nth-of-type(2) {
|
48 |
+
display: none;
|
49 |
+
}
|
50 |
+
.main > div:first-of-type {
|
51 |
+
padding: 1rem 0;
|
52 |
+
}
|
53 |
+
</style>
|
54 |
+
""", unsafe_allow_html=True)
|
55 |
+
|
56 |
+
with st.sidebar:
|
57 |
+
selected = option_menu(None, [page["title"] for page in self.pages],
|
58 |
+
icons=[page["icon"] for page in self.pages],
|
59 |
+
menu_icon="cast", default_index=0)
|
60 |
+
|
61 |
+
# Run the selected page
|
62 |
+
for index, item in enumerate(self.pages):
|
63 |
+
if item["title"] == selected:
|
64 |
+
self.pages[index]["function"]()
|
65 |
+
break
|
pages/ask.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import colorsys
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
|
7 |
+
import nltk
|
8 |
+
import numpy as np
|
9 |
+
from nltk import tokenize
|
10 |
+
|
11 |
+
nltk.download('punkt')
|
12 |
+
from google.oauth2 import service_account
|
13 |
+
from google.cloud import texttospeech
|
14 |
+
|
15 |
+
from typing import Dict, Optional, List
|
16 |
+
|
17 |
+
import jwt
|
18 |
+
import requests
|
19 |
+
import streamlit as st
|
20 |
+
from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
21 |
+
|
22 |
+
JWT_SECRET = st.secrets["api_secret"]
|
23 |
+
JWT_ALGORITHM = st.secrets["api_algorithm"]
|
24 |
+
INFERENCE_TOKEN = st.secrets["api_inference"]
|
25 |
+
CONTEXT_API_URL = st.secrets["api_context"]
|
26 |
+
LFQA_API_URL = st.secrets["api_lfqa"]
|
27 |
+
|
28 |
+
headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
|
29 |
+
API_URL = "https://api-inference.huggingface.co/models/askainet/bart_lfqa"
|
30 |
+
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def api_inference_lfqa(model_input: str):
|
36 |
+
payload = {
|
37 |
+
"inputs": model_input,
|
38 |
+
"parameters": {
|
39 |
+
"truncation": "longest_first",
|
40 |
+
"min_length": st.session_state["min_length"],
|
41 |
+
"max_length": st.session_state["max_length"],
|
42 |
+
"do_sample": st.session_state["do_sample"],
|
43 |
+
"early_stopping": st.session_state["early_stopping"],
|
44 |
+
"num_beams": st.session_state["num_beams"],
|
45 |
+
"temperature": st.session_state["temperature"],
|
46 |
+
"top_k": None,
|
47 |
+
"top_p": None,
|
48 |
+
"no_repeat_ngram_size": 3,
|
49 |
+
"num_return_sequences": 1
|
50 |
+
},
|
51 |
+
"options": {
|
52 |
+
"wait_for_model": True
|
53 |
+
}
|
54 |
+
}
|
55 |
+
data = json.dumps(payload)
|
56 |
+
logger.debug(data)
|
57 |
+
response = requests.request("POST", API_URL, headers=headers, data=data)
|
58 |
+
return json.loads(response.content.decode("utf-8"))
|
59 |
+
|
60 |
+
|
61 |
+
def inference_lfqa(model_input: str, header: dict):
|
62 |
+
payload = {
|
63 |
+
"model_input": model_input,
|
64 |
+
"parameters": {
|
65 |
+
"min_length": st.session_state["min_length"],
|
66 |
+
"max_length": st.session_state["max_length"],
|
67 |
+
"do_sample": st.session_state["do_sample"],
|
68 |
+
"early_stopping": st.session_state["early_stopping"],
|
69 |
+
"num_beams": st.session_state["num_beams"],
|
70 |
+
"temperature": st.session_state["temperature"],
|
71 |
+
"top_k": None,
|
72 |
+
"top_p": None,
|
73 |
+
"no_repeat_ngram_size": 3,
|
74 |
+
"num_return_sequences": 1
|
75 |
+
}
|
76 |
+
}
|
77 |
+
data = json.dumps(payload)
|
78 |
+
try:
|
79 |
+
response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
|
80 |
+
if response.status_code == 200:
|
81 |
+
json_response = response.content.decode("utf-8")
|
82 |
+
result = json.loads(json_response)
|
83 |
+
else:
|
84 |
+
result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
|
85 |
+
except requests.exceptions.RequestException as e:
|
86 |
+
result = {"error": e}
|
87 |
+
return result
|
88 |
+
|
89 |
+
|
90 |
+
def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
|
91 |
+
if "HuggingFace" == service_backend:
|
92 |
+
inference_response = api_inference_lfqa(model_input)
|
93 |
+
else:
|
94 |
+
inference_response = inference_lfqa(model_input, header)
|
95 |
+
return inference_response
|
96 |
+
|
97 |
+
|
98 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
99 |
+
def hf_tts(text: str):
|
100 |
+
payload = {
|
101 |
+
"inputs": text,
|
102 |
+
"parameters": {
|
103 |
+
"vocoder_tag": "str_or_none(none)",
|
104 |
+
"threshold": 0.5,
|
105 |
+
"minlenratio": 0.0,
|
106 |
+
"maxlenratio": 10.0,
|
107 |
+
"use_att_constraint": False,
|
108 |
+
"backward_window": 1,
|
109 |
+
"forward_window": 3,
|
110 |
+
"speed_control_alpha": 1.0,
|
111 |
+
"noise_scale": 0.333,
|
112 |
+
"noise_scale_dur": 0.333
|
113 |
+
},
|
114 |
+
"options": {
|
115 |
+
"wait_for_model": True
|
116 |
+
}
|
117 |
+
}
|
118 |
+
data = json.dumps(payload)
|
119 |
+
response = requests.request("POST", API_URL_TTS, headers=headers, data=data)
|
120 |
+
return response.content
|
121 |
+
|
122 |
+
|
123 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
124 |
+
def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
|
125 |
+
config = {
|
126 |
+
"private_key_id": private_key_id,
|
127 |
+
"private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
|
128 |
+
"client_email": client_email,
|
129 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
130 |
+
}
|
131 |
+
credentials = service_account.Credentials.from_service_account_info(config)
|
132 |
+
client = texttospeech.TextToSpeechClient(credentials=credentials)
|
133 |
+
|
134 |
+
synthesis_input = texttospeech.SynthesisInput(text=text)
|
135 |
+
|
136 |
+
# Build the voice request, select the language code ("en-US") and the ssml
|
137 |
+
# voice gender ("neutral")
|
138 |
+
voice = texttospeech.VoiceSelectionParams(language_code="en-US",
|
139 |
+
ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL)
|
140 |
+
|
141 |
+
# Select the type of audio file you want returned
|
142 |
+
audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)
|
143 |
+
|
144 |
+
# Perform the text-to-speech request on the text input with the selected
|
145 |
+
# voice parameters and audio file type
|
146 |
+
response = client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
|
147 |
+
return response
|
148 |
+
|
149 |
+
|
150 |
+
def request_context_passages(question, header):
|
151 |
+
try:
|
152 |
+
response = requests.request("GET", CONTEXT_API_URL + question, headers=header)
|
153 |
+
if response.status_code == 200:
|
154 |
+
json_response = response.content.decode("utf-8")
|
155 |
+
result = json.loads(json_response)
|
156 |
+
else:
|
157 |
+
result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
|
158 |
+
except requests.exceptions.RequestException as e:
|
159 |
+
result = {"error": e}
|
160 |
+
|
161 |
+
return result
|
162 |
+
|
163 |
+
|
164 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
165 |
+
def get_sentence_transformer():
|
166 |
+
return SentenceTransformer('all-MiniLM-L6-v2')
|
167 |
+
|
168 |
+
|
169 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
170 |
+
def get_sentence_transformer_encoding(sentences):
|
171 |
+
model = get_sentence_transformer()
|
172 |
+
return model.encode([sentence for sentence in sentences], convert_to_tensor=True)
|
173 |
+
|
174 |
+
|
175 |
+
def sign_jwt() -> Dict[str, str]:
|
176 |
+
payload = {
|
177 |
+
"expires": time.time() + 6000
|
178 |
+
}
|
179 |
+
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
180 |
+
return token
|
181 |
+
|
182 |
+
|
183 |
+
def extract_sentences_from_passages(passages):
|
184 |
+
sentences = []
|
185 |
+
for idx, node in enumerate(passages):
|
186 |
+
sentences.extend(tokenize.sent_tokenize(node["text"]))
|
187 |
+
return sentences
|
188 |
+
|
189 |
+
|
190 |
+
def similarity_color_picker(similarity: float):
|
191 |
+
value = int(similarity * 75)
|
192 |
+
rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
|
193 |
+
return [round(255 * x) for x in rgb]
|
194 |
+
|
195 |
+
|
196 |
+
def rgb_to_hex(rgb):
|
197 |
+
return '%02x%02x%02x' % tuple(rgb)
|
198 |
+
|
199 |
+
|
200 |
+
def similiarity_to_hex(similarity: float):
|
201 |
+
return rgb_to_hex(similarity_color_picker(similarity))
|
202 |
+
|
203 |
+
|
204 |
+
def rerank(question: str, passages: List[str], include_rank: int = 4) -> List[str]:
|
205 |
+
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
206 |
+
question_passage_combinations = [[question, p["text"]] for p in passages]
|
207 |
+
|
208 |
+
# Compute the similarity scores for these combinations
|
209 |
+
similarity_scores = ce.predict(question_passage_combinations)
|
210 |
+
|
211 |
+
# Sort the scores in decreasing order
|
212 |
+
sim_ranking_idx = np.flip(np.argsort(similarity_scores))
|
213 |
+
return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]
|
214 |
+
|
215 |
+
|
216 |
+
def answer_to_context_similarity(generated_answer, context_passages, topk=3):
|
217 |
+
context_sentences = extract_sentences_from_passages(context_passages)
|
218 |
+
context_sentences_e = get_sentence_transformer_encoding(context_sentences)
|
219 |
+
answer_sentences = tokenize.sent_tokenize(generated_answer)
|
220 |
+
answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
|
221 |
+
search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
|
222 |
+
result = []
|
223 |
+
for idx, r in enumerate(search_result):
|
224 |
+
context = []
|
225 |
+
for idx_c in range(topk):
|
226 |
+
context.append({"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]})
|
227 |
+
result.append({"answer": answer_sentences[idx], "context": context})
|
228 |
+
return result
|
229 |
+
|
230 |
+
|
231 |
+
def post_process_answer(generated_answer):
|
232 |
+
result = generated_answer
|
233 |
+
# detect sentence boundaries regex pattern
|
234 |
+
regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
|
235 |
+
answer_sentences = tokenize.sent_tokenize(generated_answer)
|
236 |
+
# do we have truncated last sentence?
|
237 |
+
if len(answer_sentences) > len(re.findall(regex, generated_answer)):
|
238 |
+
drop_last_sentence = " ".join(s for s in answer_sentences[:-1])
|
239 |
+
result = drop_last_sentence
|
240 |
+
return result.strip()
|
241 |
+
|
242 |
+
|
243 |
+
def format_score(value: float, precision=2):
|
244 |
+
return f"{value:.{precision}f}"
|
245 |
+
|
246 |
+
|
247 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
248 |
+
def get_answer(question: str):
|
249 |
+
if not question:
|
250 |
+
return {}
|
251 |
+
|
252 |
+
resp: Dict[str, str] = {}
|
253 |
+
if question and len(question.split()) > 3:
|
254 |
+
header = {"Authorization": f"Bearer {sign_jwt()}"}
|
255 |
+
context_passages = request_context_passages(question, header)
|
256 |
+
if "error" in context_passages:
|
257 |
+
resp = context_passages
|
258 |
+
else:
|
259 |
+
context_passages = rerank(question, context_passages)
|
260 |
+
conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
|
261 |
+
model_input = f'question: {question} context: {conditioned_context}'
|
262 |
+
|
263 |
+
inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
|
264 |
+
if "error" in inference_response:
|
265 |
+
resp = inference_response
|
266 |
+
else:
|
267 |
+
resp["context_passages"] = context_passages
|
268 |
+
resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
|
269 |
+
else:
|
270 |
+
resp = {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}
|
271 |
+
return resp
|
272 |
+
|
273 |
+
|
274 |
+
def app():
|
275 |
+
with open('style.css') as f:
|
276 |
+
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
277 |
+
footer = """
|
278 |
+
<div class="footer-custom">
|
279 |
+
Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
|
280 |
+
LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
|
281 |
+
Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
|
282 |
+
<a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
|
283 |
+
</div>
|
284 |
+
"""
|
285 |
+
st.markdown(footer, unsafe_allow_html=True)
|
286 |
+
|
287 |
+
st.title('Wikipedia Assistant')
|
288 |
+
|
289 |
+
question = st.text_input(
|
290 |
+
label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"')
|
291 |
+
|
292 |
+
spinner = st.empty()
|
293 |
+
if question !="":
|
294 |
+
spinner.markdown(
|
295 |
+
f"""
|
296 |
+
<div class="loader-wrapper">
|
297 |
+
<div class="loader">
|
298 |
+
</div>
|
299 |
+
<p>Generating answer for: <b>{question}</b></p>
|
300 |
+
</div>
|
301 |
+
<label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
|
302 |
+
""",
|
303 |
+
unsafe_allow_html=True,
|
304 |
+
)
|
305 |
+
|
306 |
+
question_response = get_answer(question)
|
307 |
+
if question_response:
|
308 |
+
if "error" in question_response:
|
309 |
+
st.warning(question_response["error"])
|
310 |
+
else:
|
311 |
+
spinner.markdown(f"")
|
312 |
+
generated_answer = question_response["answer"]
|
313 |
+
context_passages = question_response["context_passages"]
|
314 |
+
sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
|
315 |
+
sentences = "<div class='sentence-wrapper'>"
|
316 |
+
for item in sentence_similarity:
|
317 |
+
sentences += '<span>'
|
318 |
+
score = item["context"][0]["score"]
|
319 |
+
support_sentence = item["context"][0]["source"]
|
320 |
+
sentences += "".join([
|
321 |
+
f' {item["answer"]}',
|
322 |
+
f'<span style="background-color: #{similiarity_to_hex(score)}" class="tooltip">',
|
323 |
+
f'{format_score(score, precision=1)}',
|
324 |
+
f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>'
|
325 |
+
])
|
326 |
+
sentences += '</span>'
|
327 |
+
sentences += '</span>'
|
328 |
+
st.markdown(sentences, unsafe_allow_html=True)
|
329 |
+
|
330 |
+
with st.spinner("Generating audio..."):
|
331 |
+
if st.session_state["tts"] == "HuggingFace":
|
332 |
+
audio_file = hf_tts(generated_answer)
|
333 |
+
with open("out.flac", "wb") as f:
|
334 |
+
f.write(audio_file)
|
335 |
+
else:
|
336 |
+
audio_file = google_tts(generated_answer, st.secrets["private_key_id"],
|
337 |
+
st.secrets["private_key"], st.secrets["client_email"])
|
338 |
+
with open("out.mp3", "wb") as f:
|
339 |
+
f.write(audio_file.audio_content)
|
340 |
+
|
341 |
+
audio_file = "out.flac" if st.session_state["tts"] == "HuggingFace" else "out.mp3"
|
342 |
+
st.audio(audio_file)
|
343 |
+
|
344 |
+
st.markdown("""<hr></hr>""", unsafe_allow_html=True)
|
345 |
+
|
346 |
+
model = get_sentence_transformer()
|
347 |
+
|
348 |
+
col1, col2 = st.columns(2)
|
349 |
+
|
350 |
+
with col1:
|
351 |
+
st.subheader("Context")
|
352 |
+
with col2:
|
353 |
+
selection = st.selectbox(
|
354 |
+
label="",
|
355 |
+
options=('Paragraphs', 'Sentences', 'Answer Similarity'),
|
356 |
+
help="Context represents Wikipedia passages used to generate the answer")
|
357 |
+
question_e = model.encode(question, convert_to_tensor=True)
|
358 |
+
if selection == "Paragraphs":
|
359 |
+
sentences = extract_sentences_from_passages(context_passages)
|
360 |
+
context_e = get_sentence_transformer_encoding(sentences)
|
361 |
+
scores = util.cos_sim(question_e.repeat(context_e.shape[0], 1), context_e)
|
362 |
+
similarity_scores = scores[0].squeeze().tolist()
|
363 |
+
for idx, node in enumerate(context_passages):
|
364 |
+
node["answer_similarity"] = "{0:.2f}".format(similarity_scores[idx])
|
365 |
+
context_passages = sorted(context_passages, key=lambda x: x["answer_similarity"], reverse=True)
|
366 |
+
st.json(context_passages)
|
367 |
+
elif selection == "Sentences":
|
368 |
+
sentences = extract_sentences_from_passages(context_passages)
|
369 |
+
sentences_e = get_sentence_transformer_encoding(sentences)
|
370 |
+
scores = util.cos_sim(question_e.repeat(sentences_e.shape[0], 1), sentences_e)
|
371 |
+
sentence_similarity_scores = scores[0].squeeze().tolist()
|
372 |
+
result = []
|
373 |
+
for idx, sentence in enumerate(sentences):
|
374 |
+
result.append(
|
375 |
+
{"text": sentence, "answer_similarity": "{0:.2f}".format(sentence_similarity_scores[idx])})
|
376 |
+
context_sentences = json.dumps(sorted(result, key=lambda x: x["answer_similarity"], reverse=True))
|
377 |
+
st.json(context_sentences)
|
378 |
+
else:
|
379 |
+
st.json(sentence_similarity)
|
pages/info.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
def app():
|
5 |
+
with open('style.css') as f:
|
6 |
+
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
7 |
+
footer = """
|
8 |
+
<div class="footer-custom">
|
9 |
+
Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
|
10 |
+
LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
|
11 |
+
Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
|
12 |
+
</div>
|
13 |
+
"""
|
14 |
+
st.markdown(footer, unsafe_allow_html=True)
|
15 |
+
|
16 |
+
st.subheader("Intro")
|
17 |
+
intro = """
|
18 |
+
<div class="text">
|
19 |
+
Wikipedia Assistant is an example of a task usually referred to as the Long-Form Question Answering (LFQA).
|
20 |
+
These systems function by querying large document stores for relevant information and subsequently using
|
21 |
+
the retrieved documents to generate accurate, multi-sentence answers. The documents related to a given
|
22 |
+
query, colloquially called context passages, are not used merely as source tokens for extracted answers,
|
23 |
+
but instead provide a larger context for the synthesis of original, abstractive long-form answers.
|
24 |
+
LFQA systems usually consist of three components:
|
25 |
+
<ul>
|
26 |
+
<li>A document store including content passages for a variety of topics</li>
|
27 |
+
<li>Encoder models to encode documents/questions such that it is possible to query the document store</li>
|
28 |
+
<li>A Seq2Seq language model capable of generating paragraph-long answers when given a question and
|
29 |
+
context passages retrieved from the document store</li>
|
30 |
+
</ul>
|
31 |
+
</div>
|
32 |
+
<br>
|
33 |
+
"""
|
34 |
+
st.markdown(intro, unsafe_allow_html=True)
|
35 |
+
st.image("lfqa.png", caption="LFQA Architecture")
|
36 |
+
st.subheader("UI/UX")
|
37 |
+
st.write("Each sentence in the generated answer ends with a coloured tooltip; the colour ranges from red to green. "
|
38 |
+
"The tooltip contains a value representing answer sentence similarity to a specific sentence in the "
|
39 |
+
"Wikipedia context passages retrieved. Mouseover on the tooltip will show the sentence from the "
|
40 |
+
"Wikipedia context passage. If a sentence similarity is 1.0, the seq2seq model extracted and "
|
41 |
+
"copied the sentence verbatim from Wikipedia context passages. Lower values of sentence "
|
42 |
+
"similarity indicate the seq2seq model is struggling to generate a relevant sentence for the question "
|
43 |
+
"asked.")
|
44 |
+
st.image("wikipedia_answer.png", caption="Answer with similarity tooltips")
|
45 |
+
st.write("Below the generated answer are question-related Wikipedia context paragraphs (passages). One can view "
|
46 |
+
"these passages in a raw format retrieved using the 'Paragraphs' select menu option. The 'Sentences' menu "
|
47 |
+
"option shows the same paragraphs but on a sentence level. Finally, the 'Answer Similarity' menu option "
|
48 |
+
"shows the most similar three sentences from context paragraphs to each sentence in the generated answer.")
|
49 |
+
st.image("wikipedia_context.png", caption="Context paragraphs (passages)")
|
50 |
+
|
51 |
+
tts = """
|
52 |
+
<div class="text">
|
53 |
+
Wikipedia Assistant converts the text-based answer to speech via either Google text-to-speech engine or
|
54 |
+
<a href="https://github.com/espnet" target=_blank">Espnet model</a> hosted on
|
55 |
+
<a href="https://huggingface.co/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan" target=_blank">
|
56 |
+
HuggingFace hub</a>
|
57 |
+
<br>
|
58 |
+
<br>
|
59 |
+
"""
|
60 |
+
st.markdown(tts, unsafe_allow_html=True)
|
61 |
+
|
62 |
+
st.subheader("Tips")
|
63 |
+
tips = """
|
64 |
+
<div class="text">
|
65 |
+
LFQA task is far from solved. Wikipedia Assistant will sometimes generate an answer unrelated to a question asked,
|
66 |
+
even downright wrong. However, if the question is elaborate and more specific, there is a decent chance of
|
67 |
+
getting a legible answer. LFQA systems are targeting ELI5 non-factoid type of questions. A general guideline
|
68 |
+
is - questions starting with why, what, and how are better suited than where and who questions. Be elaborate.
|
69 |
+
<br><br>
|
70 |
+
|
71 |
+
For example, to ask a science-based question, Wikipedia Assistant is better suited to answer the question: "Why do
|
72 |
+
airplane jet engines leave contrails in the sky?" than "Why do contrails exist?". Detailed and precise questions
|
73 |
+
are more likely to match the right half a dozen relevant passages in a 20+ GB Wikipedia dump to construct a good
|
74 |
+
answer.
|
75 |
+
</div>
|
76 |
+
<br>
|
77 |
+
"""
|
78 |
+
st.markdown(tips, unsafe_allow_html=True)
|
79 |
+
st.subheader("Technical details")
|
80 |
+
techinical_intro = """
|
81 |
+
<div class="text technical-details-info">
|
82 |
+
A question asked will be encoded with an <a href="https://huggingface.co/vblagoje/dpr-question_encoder-single-lfqa-wiki" target=_blank">encoder</a>
|
83 |
+
and sent to a server to find the most relevant Wikipedia passages. The Wikipedia <a href="https://huggingface.co/datasets/kilt_wikipedia" target=_blank">passages</a>
|
84 |
+
were previously encoded using a passage <a href="https://huggingface.co/vblagoje/dpr-ctx_encoder-single-lfqa-wiki" target=_blank">encoder</a> and
|
85 |
+
stored in the <a href="https://github.com/facebookresearch/faiss" target=_blank">Faiss</a> index. The question matching passages (a.k.a context passages) are retrieved from the Faiss
|
86 |
+
index and passed to a BART-based seq2seq <a href="https://huggingface.co/vblagoje/bart_lfqa" target=_blank">model</a> to
|
87 |
+
synthesize an original answer to the question.
|
88 |
+
|
89 |
+
</div>
|
90 |
+
"""
|
91 |
+
st.markdown(techinical_intro, unsafe_allow_html=True)
|
92 |
+
|
pages/settings.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
settings = {}
|
4 |
+
|
5 |
+
def app():
|
6 |
+
st.markdown("""
|
7 |
+
<style>
|
8 |
+
div[data-testid="stForm"] {
|
9 |
+
border: 0;
|
10 |
+
}
|
11 |
+
.footer-custom {
|
12 |
+
position: fixed;
|
13 |
+
bottom: 0;
|
14 |
+
width: 100%;
|
15 |
+
color: var(--text-color);
|
16 |
+
max-width: 698px;
|
17 |
+
font-size: 14px;
|
18 |
+
height: 50px;
|
19 |
+
padding: 10px 0;
|
20 |
+
z-index: 50;
|
21 |
+
}
|
22 |
+
footer {
|
23 |
+
display: none !important;
|
24 |
+
}
|
25 |
+
.footer-custom a {
|
26 |
+
color: var(--text-color);
|
27 |
+
}
|
28 |
+
button[kind="formSubmit"]{
|
29 |
+
margin-top: 40px;
|
30 |
+
border-radius: 20px;
|
31 |
+
padding: 5px 20px;
|
32 |
+
font-size: 18px;
|
33 |
+
background-color: var(--primary-color);
|
34 |
+
}
|
35 |
+
#lfqa-model-parameters {
|
36 |
+
margin-bottom: 50px;
|
37 |
+
font-size: 36px;
|
38 |
+
}
|
39 |
+
#tts-model-parameters {
|
40 |
+
font-size: 36px;
|
41 |
+
margin-top: 50px;
|
42 |
+
}
|
43 |
+
.stAlert {
|
44 |
+
width: 250px;
|
45 |
+
margin-top: 32px;
|
46 |
+
}
|
47 |
+
</style>
|
48 |
+
""", unsafe_allow_html=True)
|
49 |
+
|
50 |
+
with st.form("settings"):
|
51 |
+
footer = """
|
52 |
+
<div class="footer-custom">
|
53 |
+
Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
|
54 |
+
LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
|
55 |
+
Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
|
56 |
+
</div>
|
57 |
+
"""
|
58 |
+
st.markdown(footer, unsafe_allow_html=True)
|
59 |
+
|
60 |
+
st.title("LFQA model parameters")
|
61 |
+
|
62 |
+
settings["min_length"] = st.slider("Min length", 20, 80, st.session_state["min_length"],
|
63 |
+
help="Min response length (words)")
|
64 |
+
st.markdown("""<hr></hr>""", unsafe_allow_html=True)
|
65 |
+
settings["max_length"] = st.slider("Max length", 128, 320, st.session_state["max_length"],
|
66 |
+
help="Max response length (words)")
|
67 |
+
st.markdown("""<hr></hr>""", unsafe_allow_html=True)
|
68 |
+
col1, col2 = st.columns(2)
|
69 |
+
with col1:
|
70 |
+
settings["do_sample"] = st.checkbox("Use sampling", st.session_state["do_sample"],
|
71 |
+
help="Whether or not to use sampling ; use greedy decoding otherwise.")
|
72 |
+
with col2:
|
73 |
+
settings["early_stopping"] = st.checkbox("Early stopping", st.session_state["early_stopping"],
|
74 |
+
help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not.")
|
75 |
+
st.markdown("""<hr></hr>""", unsafe_allow_html=True)
|
76 |
+
settings["num_beams"] = st.slider("Num beams", 1, 16, st.session_state["num_beams"],
|
77 |
+
help="Number of beams for beam search. 1 means no beam search.")
|
78 |
+
st.markdown("""<hr></hr>""", unsafe_allow_html=True)
|
79 |
+
settings["temperature"] = st.slider("Temperature", 0.0, 1.0, st.session_state["temperature"], step=0.1,
|
80 |
+
help="The value used to module the next token probabilities")
|
81 |
+
|
82 |
+
st.title("TTS model parameters")
|
83 |
+
settings["tts"] = st.selectbox(label="Engine", options=("Google", "HuggingFace"),
|
84 |
+
index=["Google", "HuggingFace"].index(st.session_state["tts"]),
|
85 |
+
help="Answer text-to-speech engine")
|
86 |
+
|
87 |
+
# Every form must have a submit button.
|
88 |
+
col3, col4, col5, col6 = st.columns(4)
|
89 |
+
with col3:
|
90 |
+
submitted = st.form_submit_button("Save")
|
91 |
+
with col4:
|
92 |
+
if submitted:
|
93 |
+
for k, v in settings.items():
|
94 |
+
st.session_state[k] = v
|
95 |
+
st.success('App settings saved successfully.')
|
requirements-dev.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
datasets
|
3 |
+
transformers
|
4 |
+
sentence_transformers
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nltk
|
2 |
+
click==8.0.3
|
3 |
+
streamlit==1.9.0
|
4 |
+
sentence_transformers
|
5 |
+
requests
|
6 |
+
pyjwt
|
7 |
+
typing
|
8 |
+
streamlit_option_menu
|
9 |
+
google-auth
|
10 |
+
google-cloud-texttospeech
|
11 |
+
protobuf<=3.20.1
|
style.css
ADDED
@@ -0,0 +1,178 @@
.row-widget.stTextInput > div:first-of-type {
  background: #fff;
  display: flex;
  border: 1px solid #dfe1e5;
  box-shadow: none;
  border-radius: 24px;
  height: 50px;
  width: auto;
  margin: 10px auto 30px;
}

.row-widget.stTextInput > div:first-of-type:hover,
.row-widget.stTextInput > div:first-of-type:focus {
  box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.2);
}

.row-widget.stTextInput .st-bq {
  background-color: #fff;
}

.row-widget.stTextInput > label {
  color: #b3b3b3;
}

.row-widget.stButton > button {
  border-radius: 24px;
  background-color: #B6C9B1;
  color: #fff;
  border: none;
  padding: 6px 20px;
  float: right;
  background-image: none;
}

.row-widget.stButton > button:hover {
  box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.2);
}

.row-widget.stButton > button:focus {
  border: none;
  color: #fff;
}

.footer-custom {
  position: fixed;
  bottom: 0;
  width: 100%;
  color: var(--text-color);
  max-width: 698px;
  font-size: 14px;
  height: 50px;
  padding: 10px 0;
  z-index: 50;
}

.main {
  padding: 20px;
}

footer {
  display: none !important;
}

.footer-custom a {
  color: var(--text-color);
}

#wikipedia-assistant {
  font-size: 36px;
}

.generated-answer p {
  font-size: 16px;
  font-weight: bold;
}

.react-json-view {
  margin: 40px 0 80px;
}

.tooltip {
  text-align: center;
  line-height: 20px;
  display: table-caption;
  font-size: 10px;
  border-radius: 50%;
  height: 20px;
  width: 20px;
  position: relative;
  cursor: pointer;
  color: #000;
}

.tooltip .tooltiptext {
  visibility: hidden;
  width: 280px;
  text-align: center;
  border-radius: 6px;
  padding: 10px;
  position: absolute;
  z-index: 1;
  top: 25px;
  left: 50%;
  margin-left: -140px;
  font-size: 14px;
  background-color: #fff;
  border: 1px solid #ccc;
  box-shadow: 0px 0px 3px 1px rgba(0, 0, 0, 0.16);
  color: #000;
}

.tooltip:hover .tooltiptext {
  visibility: visible;
}

.sentence-wrapper {
  border-left: 4px solid #ffc423;
  padding-left: 20px;
  margin-bottom: 40px;
}

#context {
  padding: 2rem 0 1rem;
}

hr {
  margin: 2em 0 1em;
}

.technical-details-info {
  margin-bottom: 100px;
}

.loader-wrapper {
  display: flex;
  align-items: center;
  background-color: rgba(250, 202, 43, 0.2);
  padding: 15px 20px;
  border-radius: 6px;
}

.loader-wrapper p {
  margin-bottom: 0;
  margin-left: 20px;
}

.loader {
  width: 30px;
  height: 30px;
  border: dotted 5px #868686;
  border-radius: 100%;
  animation: spin 1s linear infinite;
}

.loader-note {
  font-size: 14px;
  color: #b3b3b3;
  margin-left: 5px;
}

@keyframes spin {
  0% {
    transform: rotate(0deg) scale(0.8);
    border-top-color: transparent;
    border-right-color: transparent;
  }
  50% {
    transform: rotate(180deg) scale(1.2);
    border-color: #949494;
    border-top-color: transparent;
    border-right-color: transparent;
  }
  100% {
    transform: rotate(360deg) scale(0.8);
    border-color: #bbbbbb;
    border-top-color: transparent;
    border-right-color: transparent;
  }
}
training/run_retriever_no_trainer.py
ADDED
@@ -0,0 +1,381 @@
import argparse
import functools
import logging
import math
from random import choice, randint

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils import checkpoint
from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
from tqdm.auto import tqdm
from transformers import get_scheduler, AutoTokenizer, AutoModel, AdamW, SchedulerType

logger = logging.getLogger(__name__)


def get_parser():
    parser = argparse.ArgumentParser(description="Train ELI5 retriever")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="vblagoje/lfqa",
        help="The name of the dataset to use (via the datasets library).",
    )

    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1024,
    )

    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=1024,
        help="Batch size (per device) for the evaluation dataloader.",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
    )

    parser.add_argument(
        "--checkpoint_batch_size",
        type=int,
        default=32,
    )

    parser.add_argument(
        "--pretrained_model_name",
        type=str,
        default="google/bert_uncased_L-8_H-768_A-12",
    )

    parser.add_argument(
        "--model_save_name",
        type=str,
        default="eli5_retriever_model_l-12_h-768_b-512-512",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.2,
    )

    parser.add_argument(
        "--log_freq",
        type=int,
        default=500,
        help="Log train/validation loss every log_freq update steps"
    )

    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=4,
    )

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",  # this is linear with warmup
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )

    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=100,
        help="Number of steps for the warmup in the lr scheduler."
    )

    parser.add_argument(
        "--warmup_percentage",
        type=float,
        default=0.08,
        help="Percentage of the total training steps used for the warmup in the lr scheduler."
    )
    return parser


class RetrievalQAEmbedder(torch.nn.Module):
    def __init__(self, sent_encoder):
        super(RetrievalQAEmbedder, self).__init__()
        dim = sent_encoder.config.hidden_size
        self.bert_query = sent_encoder
        self.output_dim = 128
        self.project_query = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.project_doc = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")

    def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
        # reproduces BERT forward pass with checkpointing
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            return self.bert_query(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * self.bert_query.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = self.bert_query.get_extended_attention_mask(
                attention_mask, input_shape, device
            )

            # define function for checkpointing
            def partial_encode(*inputs):
                encoder_outputs = self.bert_query.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask, )
                sequence_output = encoder_outputs[0]
                pooled_output = self.bert_query.pooler(sequence_output)
                return pooled_output

            # run embedding layer on everything at once
            embedding_output = self.bert_query.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
        q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
        return self.project_query(q_reps)

    def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
        a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
        return self.project_doc(a_reps)

    def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
        device = q_ids.device
        q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
        a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
        compare_scores = torch.mm(q_reps, a_reps.t())
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss


class ELI5DatasetQARetriever(Dataset):
    def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
        self.data = examples_array
        self.answer_thres = extra_answer_threshold
        self.min_length = min_answer_length
        self.training = training
        self.n_samples = self.data.num_rows if n_samples is None else n_samples

    def __len__(self):
        return self.n_samples

    def make_example(self, idx):
        example = self.data[idx]
        question = example["title"]
        if self.training:
            answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
            answer_tab = choice(answers).split(" ")
            start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
            answer_span = " ".join(answer_tab[start_idx:])
        else:
            answer_span = example["answers"]["text"][0]
        return question, answer_span

    def __getitem__(self, idx):
        return self.make_example(idx % self.data.num_rows)


def make_qa_retriever_batch(qa_list, tokenizer, max_len=64):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer(q_ls, padding="max_length", max_length=max_len, truncation=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]),
        torch.LongTensor(q_toks["attention_mask"])
    )
    a_toks = tokenizer(a_ls, padding="max_length", max_length=max_len, truncation=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]),
        torch.LongTensor(a_toks["attention_mask"]),
    )
    return q_ids, q_mask, a_ids, a_mask


def evaluate_qa_retriever(model, data_loader):
    # make iterator
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(epoch_iterator):
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
        return tot_loss / (step + 1)


def train(config):
    set_seed(42)
    args = config["args"]
    data_files = {"train": "train.json", "validation": "validation.json", "test": "test.json"}
    eli5 = load_dataset(args.dataset_name, data_files=data_files)

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    logger.info(accelerator.state)

    # prepare torch Dataset objects
    train_dataset = ELI5DatasetQARetriever(eli5['train'], training=True)
    valid_dataset = ELI5DatasetQARetriever(eli5['validation'], training=False)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
    base_model = AutoModel.from_pretrained(args.pretrained_model_name)

    model = RetrievalQAEmbedder(base_model)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

    model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
                                  sampler=RandomSampler(train_dataset), collate_fn=model_collate_fn)

    model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
                                 sampler=SequentialSampler(valid_dataset), collate_fn=model_collate_fn)

    # train the model
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
                                                                              train_dataloader, eval_dataloader)
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
                                                                                     args.warmup_percentage)
    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  Warmup steps = {num_warmup_steps}")
    logger.info(f"  Logging training progress every {args.log_freq} optimization steps")

    loc_loss = 0.0
    current_loss = 0.0
    checkpoint_step = 0

    completed_steps = checkpoint_step
    progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
                        disable=not accelerator.is_local_main_process)
    for epoch in range(args.num_train_epochs):
        model.train()
        batch = next(iter(train_dataloader))
        for step in range(1000):
            # for step, batch in enumerate(train_dataloader, start=checkpoint_step):
            # model inputs
            q_ids, q_mask, a_ids, a_mask = batch
            pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
            loss = pre_loss.sum() / args.gradient_accumulation_steps
            accelerator.backward(loss)
            loc_loss += loss.item()
            if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                current_loss = loc_loss
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loc_loss)
                loc_loss = 0
                completed_steps += 1

            if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
                logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
                if args.wandb and accelerator.is_local_main_process:
                    import wandb
                    wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})

            if completed_steps >= args.max_train_steps:
                break

        logger.info("Saving model {}".format(args.model_save_name))
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
        eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
        logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))


if __name__ == "__main__":
    parser = get_parser()
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Whether to use W&B logging",
    )
    main_args, _ = parser.parse_known_args()
    config = {"args": main_args}
    if main_args.wandb:
        import wandb
        wandb.init(project="Retriever")

    train(config=config)
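The retriever's forward pass scores every question against every answer in the batch and treats the diagonal as the gold pairing (in-batch negatives), averaging the question-to-answer and answer-to-question cross-entropy losses. The following is a minimal, self-contained sketch of that loss on random tensors, assuming only PyTorch; it is illustrative and not part of the training script above.

# Minimal sketch of the in-batch negatives loss used by RetrievalQAEmbedder.forward,
# with random embeddings standing in for the projected question/answer representations.
import torch

batch, dim = 4, 128
q_reps = torch.randn(batch, dim)   # projected question embeddings
a_reps = torch.randn(batch, dim)   # projected answer embeddings

scores = torch.mm(q_reps, a_reps.t())           # (batch, batch) similarity matrix
targets = torch.arange(batch)                   # row i should match column i
ce = torch.nn.CrossEntropyLoss(reduction="mean")
loss = (ce(scores, targets) + ce(scores.t(), targets)) / 2  # symmetric q->a and a->q loss
print(loss.item())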
training/run_retriever_no_trainer_gpl.py
ADDED
@@ -0,0 +1,403 @@
import argparse
import logging
import math
from dataclasses import dataclass
from typing import List, Any, Union, Optional

import torch
import ujson
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
from tqdm.auto import tqdm
from transformers import get_scheduler, AutoTokenizer, AutoModel, AdamW, SchedulerType, PreTrainedTokenizerBase, AutoModelForSequenceClassification, BatchEncoding
from transformers.file_utils import PaddingStrategy

logger = logging.getLogger(__name__)


def get_parser():
    parser = argparse.ArgumentParser(description="Train LFQA retriever")
    parser.add_argument(
        "--dpr_input_file",
        type=str,
        help="DPR formatted input file with question/positive/negative pairs in a JSONL file",
    )

    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=32,
    )

    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the evaluation dataloader.",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
    )

    parser.add_argument(
        "--pretrained_model_name",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
    )

    parser.add_argument(
        "--ce_model_name",
        type=str,
        default="cross-encoder/ms-marco-MiniLM-L-6-v2",
    )

    parser.add_argument(
        "--model_save_name",
        type=str,
        default="eli5_retriever_model_l-12_h-768_b-512-512",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
    )

    parser.add_argument(
        "--log_freq",
        type=int,
        default=500,
        help="Log train/validation loss every log_freq update steps"
    )

    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=4,
    )

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",  # this is linear with warmup
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )

    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=100,
        help="Number of steps for the warmup in the lr scheduler."
    )

    parser.add_argument(
        "--warmup_percentage",
        type=float,
        default=0.08,
        help="Percentage of the total training steps used for the warmup in the lr scheduler."
    )
    return parser


@dataclass
class InputExample:
    guid: str = ""
    texts: List[str] = None
    label: Union[int, float] = 0


class DPRDataset(Dataset):
    """
    Dataset in DPR format of question, answers, positive, negative, and hard negative passages
    See https://github.com/facebookresearch/DPR#retriever-input-data-format for more details
    """

    def __init__(self, file_path: str, include_all_positive: bool = False) -> None:
        super().__init__()
        with open(file_path, "r") as fp:
            self.data = []

            def dpr_example_to_input_example(idx, dpr_item):
                examples = []
                for p_idx, p_item in enumerate(dpr_item["positive_ctxs"]):
                    for n_idx, n_item in enumerate(dpr_item["negative_ctxs"]):
                        examples.append(InputExample(guid=[idx, p_idx, n_idx], texts=[dpr_item["question"],
                                                                                      p_item["text"],
                                                                                      n_item["text"]]))
                    if not include_all_positive:
                        break
                return examples

            for idx, line in enumerate(fp):
                self.data.extend(dpr_example_to_input_example(idx, ujson.loads(line)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


def dpr_collate_fn(batch):
    query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
    query, pos, neg = zip(*[example.texts for example in batch])
    return (query_id, pos_id, neg_id), (query, pos, neg)


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


@dataclass
class CrossEncoderCollator:
    tokenizer: PreTrainedTokenizerBase
    model: Any
    target_tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, batch):
        query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
        query, pos_passage, neg_passage = zip(*[example.texts for example in batch])
        batch_input: List[List[str]] = list(zip(query, pos_passage)) + list(zip(query, neg_passage))
        features = self.tokenizer(batch_input, padding=self.padding, truncation=True,
                                  return_tensors=self.return_tensors)
        with torch.no_grad():
            scores = self.model(**features).logits

        labels = scores[:len(query)] - scores[len(query):]
        batch_input: List[str] = list(query) + list(pos_passage) + list(neg_passage)
        encoded_input = self.target_tokenizer(batch_input, padding=True, truncation=True,
                                              max_length=256, return_tensors='pt')

        encoded_input["labels"] = labels

        return encoded_input


class RetrievalQAEmbedder(torch.nn.Module):
    def __init__(self, sent_encoder, sent_tokenizer, batch_size: int = 32):
        super(RetrievalQAEmbedder, self).__init__()
        dim = sent_encoder.config.hidden_size
        self.model = sent_encoder
        self.tokenizer = sent_tokenizer
        self.scale = 1
        self.similarity_fct = 'dot'
        self.batch_size = 32
        self.loss_fct = nn.MSELoss()

    def forward(self, examples: BatchEncoding):
        # Tokenize sentences
        labels = examples.pop("labels")
        # Compute token embeddings
        model_output = self.model(**examples)

        examples["labels"] = labels

        # Perform pooling. In this case, mean pooling
        sentence_embeddings = mean_pooling(model_output, examples['attention_mask'])
        target_shape = (3, self.batch_size, sentence_embeddings.shape[-1])
        sentence_embeddings_reshaped = torch.reshape(sentence_embeddings, target_shape)

        embeddings_query = sentence_embeddings_reshaped[0]
        embeddings_pos = sentence_embeddings_reshaped[1]
        embeddings_neg = sentence_embeddings_reshaped[2]

        if self.similarity_fct == 'cosine':
            embeddings_query = F.normalize(embeddings_query, p=2, dim=1)
            embeddings_pos = F.normalize(embeddings_pos, p=2, dim=1)
            embeddings_neg = F.normalize(embeddings_neg, p=2, dim=1)

        scores_pos = (embeddings_query * embeddings_pos).sum(dim=-1) * self.scale
        scores_neg = (embeddings_query * embeddings_neg).sum(dim=-1) * self.scale
        margin_pred = scores_pos - scores_neg
        return self.loss_fct(margin_pred, labels.squeeze())


def evaluate_qa_retriever(model, data_loader):
    # make iterator
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(epoch_iterator):
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
        return tot_loss / (step + 1)


def train(config):
    set_seed(42)
    args = config["args"]

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    logger.info(accelerator.state)

    # prepare torch Dataset objects
    train_dataset = DPRDataset(file_path=args.dpr_input_file)
    valid_dataset = Dataset()

    base_tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
    base_model = AutoModel.from_pretrained(args.pretrained_model_name)

    ce_tokenizer = AutoTokenizer.from_pretrained(args.ce_model_name)
    ce_model = AutoModelForSequenceClassification.from_pretrained(args.ce_model_name)
    _ = ce_model.eval()

    model = RetrievalQAEmbedder(base_model, base_tokenizer)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    cec = CrossEncoderCollator(model=ce_model, tokenizer=ce_tokenizer, target_tokenizer=base_tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
                                  sampler=RandomSampler(train_dataset), collate_fn=cec)

    eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
                                 sampler=SequentialSampler(valid_dataset), collate_fn=cec)

    # train the model
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
                                                                              train_dataloader, eval_dataloader)
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
                                                                                     args.warmup_percentage)
    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  Warmup steps = {num_warmup_steps}")
    logger.info(f"  Logging training progress every {args.log_freq} optimization steps")

    loc_loss = 0.0
    current_loss = 0.0
    checkpoint_step = 0

    completed_steps = checkpoint_step
    progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
                        disable=not accelerator.is_local_main_process)
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader, start=checkpoint_step):
            # model inputs
            pre_loss = model(batch)
            loss = pre_loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            loc_loss += loss.item()
            if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                current_loss = loc_loss
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loc_loss)
                loc_loss = 0
                completed_steps += 1

            if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
                # accelerator.wait_for_everyone()
                # unwrapped_model = accelerator.unwrap_model(model)
                # eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
                eval_loss = 0
                logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
                if args.wandb and accelerator.is_local_main_process:
                    import wandb
                    wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})

            if completed_steps >= args.max_train_steps:
                break

        logger.info("Saving model {}".format(args.model_save_name))
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
        eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
        logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))


if __name__ == "__main__":
    parser = get_parser()
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Whether to use W&B logging",
    )
    main_args, _ = parser.parse_known_args()
    config = {"args": main_args}
    if main_args.wandb:
        import wandb

        wandb.init(project="Retriever")

    train(config=config)
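In this GPL-style script the cross-encoder acts as a teacher: for each (query, positive, negative) triple the collator records the difference between its positive and negative scores, and the bi-encoder student is trained with MSE to reproduce that margin from dot products of mean-pooled embeddings. Below is a toy, self-contained sketch of that margin-MSE objective on random tensors (not part of the script above), assuming only PyTorch.

# Toy sketch of the margin-MSE objective realized by CrossEncoderCollator + RetrievalQAEmbedder:
# the student's (query·pos - query·neg) margin is regressed onto the cross-encoder's margin.
import torch

batch, dim = 4, 384
q = torch.randn(batch, dim, requires_grad=True)   # student query embeddings
pos = torch.randn(batch, dim)                     # student positive-passage embeddings
neg = torch.randn(batch, dim)                     # student negative-passage embeddings
teacher_margin = torch.randn(batch)               # cross-encoder score(pos) - score(neg)

student_margin = (q * pos).sum(dim=-1) - (q * neg).sum(dim=-1)
loss = torch.nn.MSELoss()(student_margin, teacher_margin)
loss.backward()
print(loss.item())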
training/run_seq2seq_no_trainer.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import re
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from accelerate.utils import set_seed
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSeq2SeqLM, \
|
13 |
+
DataCollatorWithPadding
|
14 |
+
|
15 |
+
from datasets import load_dataset
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
def get_parser():
|
21 |
+
parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
|
22 |
+
parser.add_argument(
|
23 |
+
"--dataset_name",
|
24 |
+
type=str,
|
25 |
+
default="vblagoje/lfqa",
|
26 |
+
help="The name of the dataset to use (via the datasets library).",
|
27 |
+
)
|
28 |
+
|
29 |
+
parser.add_argument(
|
30 |
+
"--per_device_train_batch_size",
|
31 |
+
type=int,
|
32 |
+
default=4,
|
33 |
+
)
|
34 |
+
|
35 |
+
parser.add_argument(
|
36 |
+
"--per_device_eval_batch_size",
|
37 |
+
type=int,
|
38 |
+
default=4,
|
39 |
+
help="Batch size (per device) for the evaluation dataloader.",
|
40 |
+
)
|
41 |
+
|
42 |
+
parser.add_argument(
|
43 |
+
"--pretrained_model_name",
|
44 |
+
type=str,
|
45 |
+
default="facebook/bart-large",
|
46 |
+
)
|
47 |
+
|
48 |
+
parser.add_argument(
|
49 |
+
"--model_save_name",
|
50 |
+
type=str,
|
51 |
+
default="eli5_bart_model",
|
52 |
+
)
|
53 |
+
|
54 |
+
parser.add_argument(
|
55 |
+
"--learning_rate",
|
56 |
+
type=float,
|
57 |
+
default=2e-4,
|
58 |
+
)
|
59 |
+
|
60 |
+
parser.add_argument(
|
61 |
+
"--weight_decay",
|
62 |
+
type=float,
|
63 |
+
default=0.0,
|
64 |
+
help="Weight decay to use."
|
65 |
+
)
|
66 |
+
|
67 |
+
parser.add_argument(
|
68 |
+
"--log_freq",
|
69 |
+
type=int,
|
70 |
+
default=100,
|
71 |
+
help="Log train/validation loss every log_freq update steps"
|
72 |
+
)
|
73 |
+
|
74 |
+
parser.add_argument(
|
75 |
+
"--ignore_pad_token_for_loss",
|
76 |
+
type=bool,
|
77 |
+
default=True,
|
78 |
+
help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
|
79 |
+
)
|
80 |
+
|
81 |
+
parser.add_argument(
|
82 |
+
"--num_train_epochs",
|
83 |
+
type=int,
|
84 |
+
default=3,
|
85 |
+
)
|
86 |
+
|
87 |
+
parser.add_argument(
|
88 |
+
"--max_train_steps",
|
89 |
+
type=int,
|
90 |
+
default=None,
|
91 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
92 |
+
)
|
93 |
+
|
94 |
+
parser.add_argument(
|
95 |
+
"--gradient_accumulation_steps",
|
96 |
+
type=int,
|
97 |
+
default=16,
|
98 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
99 |
+
)
|
100 |
+
|
101 |
+
parser.add_argument(
|
102 |
+
"--pad_to_max_length",
|
103 |
+
action="store_true",
|
104 |
+
help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
|
105 |
+
)
|
106 |
+
|
107 |
+
parser.add_argument(
|
108 |
+
"--overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets"
|
109 |
+
)
|
110 |
+
|
111 |
+
parser.add_argument(
|
112 |
+
"--max_source_length",
|
113 |
+
type=int,
|
114 |
+
default=1024,
|
115 |
+
help="The maximum total input sequence length after "
|
116 |
+
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
|
117 |
+
)
|
118 |
+
|
119 |
+
parser.add_argument(
|
120 |
+
"--max_target_length",
|
121 |
+
type=int,
|
122 |
+
default=360,
|
123 |
+
help="The maximum total sequence length for target text after "
|
124 |
+
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
125 |
+
)
|
126 |
+
|
127 |
+
parser.add_argument(
|
128 |
+
"--lr_scheduler_type",
|
129 |
+
type=SchedulerType,
|
130 |
+
default="linear", # this is linear with warmup
|
131 |
+
help="The scheduler type to use.",
|
132 |
+
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
133 |
+
)
|
134 |
+
|
135 |
+
parser.add_argument(
|
136 |
+
"--num_warmup_steps",
|
137 |
+
type=int,
|
138 |
+
default=None,
|
139 |
+
help="Number of steps for the warmup in the lr scheduler."
|
140 |
+
)
|
141 |
+
|
142 |
+
parser.add_argument(
|
143 |
+
"--warmup_percentage",
|
144 |
+
type=float,
|
145 |
+
default=0.08,
|
146 |
+
help="Number of steps for the warmup in the lr scheduler."
|
147 |
+
)
|
148 |
+
return parser
|
149 |
+
|
150 |
+
|
151 |
+
def cleanup_references(text):
|
152 |
+
# URL reference where we need to remove both the link text and URL
|
153 |
+
# ...and this letter is used by most biographers as the cornerstone of Lee's personal
|
154 |
+
# views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
|
155 |
+
# ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
|
156 |
+
result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
|
157 |
+
|
158 |
+
# URL reference where we need to preserve link text but remove URL
|
159 |
+
# At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
|
160 |
+
# At the outbreak of the Civil War, Leyburn left his church and joined the South.
|
161 |
+
result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
|
162 |
+
|
163 |
+
# lastly remove just dangling _URL_[0-9]_ URL references
|
164 |
+
result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
|
165 |
+
return result
|
166 |
+
|
167 |
+
|
168 |
+
def clean_answer(text):
|
169 |
+
result = cleanup_references(text)
|
170 |
+
result = result.replace("\n", " ")
|
171 |
+
result = re.sub(r"\s\s+", " ", result)
|
172 |
+
result = re.sub(r"BULLET::::-", "", result)
|
173 |
+
return result.strip()
|
174 |
+
|
175 |
+
|
176 |
+
def clean_question(text):
|
177 |
+
result = cleanup_references(text)
|
178 |
+
result = result.replace("\n", " ")
|
179 |
+
result = re.sub(r"\s\s+", " ", result)
|
180 |
+
result = result.replace("[deleted]", "")
|
181 |
+
return result.lower().strip()
|
182 |
+
|
183 |
+
|
184 |
+
def prepare_support_docs(example):
|
185 |
+
provenances = example["output"][-1]["provenance"]
|
186 |
+
context = "<P> " + " <P> ".join([p["text"] for p in provenances])
|
187 |
+
return {"context": context}
|
188 |
+
|
189 |
+
|
190 |
+
def preprocess_eli5(examples, **fn_kwargs):
|
191 |
+
document_cache = fn_kwargs["document_cache"]
|
192 |
+
training = fn_kwargs.get("training", True)
|
193 |
+
extra_answer_threshold = fn_kwargs.get("extra_answer_threshold", 3)
|
194 |
+
include_selftext = fn_kwargs.get("include_selftext", False)
|
195 |
+
exclude_answer_patterns = fn_kwargs.get("exclude_answer_patterns", [])
|
196 |
+
|
197 |
+
questions, contexts, answers = [], [], []
|
198 |
+
for q_id, question, selftext, answer in zip(examples["q_id"], examples["title"], examples["selftext"],
|
199 |
+
examples["answers"]):
|
200 |
+
accepted_answer_idx = []
|
201 |
+
if training:
|
202 |
+
accepted_answer_idx = [idx for idx, score in enumerate(answer["score"]) if
|
203 |
+
score > extra_answer_threshold]
|
204 |
+
if not training or not accepted_answer_idx:
|
205 |
+
accepted_answer_idx = [0]
|
206 |
+
document = document_cache[q_id]
|
207 |
+
for idx in accepted_answer_idx:
|
208 |
+
skip_answer = any([p.search(answer["text"][idx]) for p in exclude_answer_patterns])
|
209 |
+
if skip_answer:
|
210 |
+
continue
|
211 |
+
if include_selftext:
|
212 |
+
questions.append(clean_question(f"{question} {selftext}"))
|
213 |
+
else:
|
214 |
+
questions.append(clean_question(question))
|
215 |
+
contexts.append(document.lower().strip())
|
216 |
+
answers.append(clean_answer(answer["text"][idx]))
|
217 |
+
|
218 |
+
return {"question": questions, "context": contexts, "answer": answers}
|
219 |
+
|
220 |
+
|
221 |
+
def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
|
222 |
+
model.eval()
|
223 |
+
num_eval_steps = math.ceil(len(dataloader))
|
224 |
+
progress_bar = tqdm(range(num_eval_steps), disable=not accelerator.is_local_main_process)
|
225 |
+
total_loss = 0.
|
226 |
+
with torch.no_grad():
|
227 |
+
for step, batch in enumerate(dataloader):
|
228 |
+
outputs = model(**batch)
|
229 |
+
loss = outputs.loss
|
230 |
+
total_loss += loss.item()
|
231 |
+
progress_bar.update(1)
|
232 |
+
progress_bar.set_postfix(loss=round((total_loss / (step + 1)), 3))
|
233 |
+
return total_loss / (step + 1)
|
234 |
+
|
235 |
+
|
236 |
+
def train(config):
|
237 |
+
set_seed(42)
|
238 |
+
args = config["args"]
|
239 |
+
eli5 = load_dataset(args.dataset_name)
|
240 |
+
|
241 |
+
support_docs = load_dataset("vblagoje/lfqa_support_docs")
|
242 |
+
|
243 |
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
244 |
+
accelerator = Accelerator()
|
245 |
+
# Make one log on every process with the configuration for debugging.
|
246 |
+
logging.basicConfig(
|
247 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
248 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
249 |
+
level=logging.INFO,
|
250 |
+
)
|
251 |
+
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
|
252 |
+
logger.info(accelerator.state)
|
253 |
+
|
254 |
+
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
|
255 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)
|
256 |
+
|
257 |
+
# Optimizer
|
258 |
+
# Split weights in two groups, one with weight decay and the other not.
|
259 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
260 |
+
optimizer_grouped_parameters = [
|
261 |
+
{
|
262 |
+
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
263 |
+
"weight_decay": args.weight_decay,
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
267 |
+
"weight_decay": 0.0,
|
268 |
+
},
|
269 |
+
]
|
270 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
|
271 |
+
|
272 |
+
processed_datasets = {}
|
273 |
+
support_docs_prepared = {}
|
274 |
+
with accelerator.main_process_first():
|
275 |
+
for split in ["train", "validation"]:
|
276 |
+
support_docs_prepared[split] = support_docs[split].map(prepare_support_docs,
|
277 |
+
batched=False,
|
278 |
+
cache_file_name=f"./support_docs_{split}.arrow",
|
279 |
+
load_from_cache_file=not args.overwrite_cache,
|
280 |
+
desc="Preparing support docs",
|
281 |
+
)
|
282 |
+
column_names = eli5["train"].column_names
|
283 |
+
for split in ["train", "validation"]:
|
284 |
+
d_cache = dict([(e["id"], e["context"]) for e in tqdm(support_docs_prepared[split],
|
285 |
+
desc=f"Adding support docs to LFQA {split}")])
|
286 |
+
processed_datasets[split] = eli5[split].map(preprocess_eli5,
|
287 |
+
batched=True,
|
288 |
+
remove_columns=column_names,
|
289 |
+
cache_file_name=f"./processed_datasets_{split}.arrow",
|
290 |
+
load_from_cache_file=not args.overwrite_cache,
|
291 |
+
desc="Preparing dataset for tokenization",
|
292 |
+
fn_kwargs={"document_cache": d_cache,
|
293 |
+
"training": split == "train",
|
294 |
+
"exclude_answer_patterns": [re.compile("not sure what you"),
|
295 |
+
re.compile("\n\n >")]}
|
296 |
+
)
|
297 |
+
|
298 |
+
padding = "max_length" if args.pad_to_max_length else False
|
299 |
+
    # Temporarily set max_target_length for training.
    max_target_length = args.max_target_length

    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def tokenize_dataset(examples):
        inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
        targets = examples["answer"]
        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
                               return_tensors="np")

        model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()
        # replace pad_token_id with label_pad_token_id to avoid loss calculation on those tokens
        labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
                                       label_pad_token_id, labels["input_ids"])

        model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
        return model_inputs

    tokenized_datasets = {}
    with accelerator.main_process_first():
        for split, dataset in processed_datasets.items():
            tokenized_datasets[split] = dataset.map(
                tokenize_dataset,
                batched=True,
                cache_file_name=f"./tokenized_dataset_{split}.arrow",
                remove_columns=dataset.column_names,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on dataset"
            )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    train_dataset.set_format(type='torch')
    eval_dataset.set_format(type='torch')

    data_collator = DataCollatorWithPadding(tokenizer, "max_length")

    # first epoch we don't shuffle
    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
                                  collate_fn=data_collator)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)

    # train the model
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
                                                                              eval_dataloader)
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
                                                                                     args.warmup_percentage)
    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num eval examples = {len(eval_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    logger.info(f" Warmup steps = {num_warmup_steps}")
    logger.info(f" Logging training progress every {args.log_freq} optimization steps")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    switched_train_dataloader = False
    for epoch in range(args.num_train_epochs):
        model.train()
        if epoch > 0 and not switched_train_dataloader:
            train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
                                          shuffle=True, collate_fn=data_collator)
            train_dataloader = accelerator.prepare(train_dataloader)
            switched_train_dataloader = True

        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = torch.mean(outputs.loss)
            accelerator.backward(loss)
            if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                progress_bar.set_postfix(loss=round(loss.item(), 3))
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

            if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
                validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
                model.train()
                logger.info(f"Train loss {loss.item()} , validation loss {validation_loss}")
                if args.wandb and accelerator.is_local_main_process:
                    import wandb
                    wandb.log({"loss": loss.item(),
                               "lr": scheduler.get_last_lr()[0],
                               "validation_loss": validation_loss,
                               "completed_steps": completed_steps})

        logger.info("Saving model {}".format(args.model_save_name))
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))

        # Calculating the validation loss over epoch
        validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)

        logger.info("Epoch: {}".format(epoch))
        logger.info("Validation loss: {}".format(validation_loss))


def main():
    parser = get_parser()
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If true, use W&B logging",
    )
    main_args, _ = parser.parse_known_args()
    config = {"args": main_args}
    if main_args.wandb:
        import wandb
        wandb.init(project="Bart_ELI5")
    train(config=config)


main()
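The label preparation above follows the usual seq2seq teacher-forcing shift: decoder_input_ids drop the last target token, labels drop the first one, and padding positions in the labels are replaced with -100 so the loss ignores them. A minimal, self-contained sketch of that transformation on toy token ids (the ids and pad id 0 are illustrative, not taken from the real tokenizer):

    import numpy as np

    pad_token_id = 0            # illustrative pad id
    label_pad_token_id = -100
    target_ids = np.array([[2, 15, 27, 42, 1, 0, 0]])   # <s> ... </s> pad pad

    decoder_input_ids = target_ids[:, :-1]               # fed to the decoder
    labels = np.where(target_ids == pad_token_id, label_pad_token_id, target_ids)[:, 1:]

    print(decoder_input_ids)    # [[ 2 15 27 42  1  0]]
    print(labels)               # [[  15   27   42    1 -100 -100]]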
util/common.py
ADDED
@@ -0,0 +1,120 @@
import re

import torch

kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                          'wikidata_info', 'history']

kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
                                    'end_character', 'title', 'section', 'text']


def clean_question(text):
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = result.replace("[deleted]", "")
    return result.lower().strip()


def cleanup_references(text):
    # URL reference where we need to remove both the link text and URL
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal
    # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)

    # URL reference where we need to preserve link text but remove URL
    # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
    # At the outbreak of the Civil War, Leyburn left his church and joined the South.
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)

    # lastly remove just dangling _URL_[0-9]_ URL references
    result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
    return result


def clean_answer(text):
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = re.sub(r"BULLET::::-", "", result)
    return trim(result.strip())


def trim(text, word_count: int = 100):
    return " ".join(text.split(" ")[:word_count])


def articles_to_paragraphs(examples):
    ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
    for bidx, example in enumerate(examples["text"]):
        last_section = ""
        for idx, p in enumerate(example["paragraph"]):
            if "Section::::" in p:
                last_section = p
            ids.append(examples["wikipedia_id"][bidx])
            titles.append(examples["wikipedia_title"][bidx])
            sections.append(last_section)
            texts.append(p)
            start_ps.append(idx)
            end_ps.append(idx)
            start_cs.append(0)
            end_cs.append(len(p))

    return {"wikipedia_id": ids, "title": titles,
            "section": sections, "text": texts,
            "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
            "start_character": start_cs,
            "end_character": end_cs
            }


def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages]
    res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk]

    # make a KILT data point
    # see https://github.com/facebookresearch/KILT#kilt-data-format
    output = []
    for a in eli5_example["answers"]["text"]:
        output.append({"answer": a})

    output.append({"provenance": [
        # evidence set for the answer from the KILT ks
        {
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list
    ]})
    return {"id": eli5_example["q_id"],
            "input": eli5_example["title"],
            "output": output,  # each element is an answer or provenance (can have multiple of each)
            "meta": None  # dataset/task specific
            }


def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True,
                               return_tensors="pt")
    with torch.no_grad():
        q_reps = question_model(query["input_ids"].to(device),
                                query["attention_mask"].to(device)).pooler_output
    return q_reps.cpu().numpy()


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
                      truncation=True, return_tensors="pt")
    with torch.no_grad():
        a_reps = ctx_model(p["input_ids"].to(device),
                           p["attention_mask"].to(device)).pooler_output
    return {"embeddings": a_reps.cpu().numpy()}
util/create_dpr_training_from_dataset.py
ADDED
@@ -0,0 +1,103 @@
import argparse
import random
import json
import re

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search, cos_sim
from tqdm.auto import tqdm
from datasets import load_dataset

from common import clean_answer, clean_question


def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                            exclude_answer_patterns, similarity_threshold=[0.5, 0.6], k=25, min_count=3):
    hard_negative_ctxs = []
    results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
                              score_function=cos_sim)
    # list of dicts
    # [{'corpus_id': 8, 'score': -0.019427383318543434},
    #  ...
    #  {'corpus_id': 10, 'score': -0.09040290117263794}]
    # hard negatives are most similar and negatives are most dissimilar to embedding_index
    hard_negative_results = results[0][1:k + 1]
    assert len(hard_negative_results) > min_count * 2
    for r in hard_negative_results:
        example = dataset[r["corpus_id"]]
        if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
            for a in example["answers"]["text"]:
                hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(hard_negative_ctxs) > min_count:
            break
    return hard_negative_ctxs[:min_count]


def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                       exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
    negative_ctxs = []
    random_sample = random.sample(range(len(dataset_embeddings)), k * 20)
    similarities = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[random_sample])[0].tolist()
    for idx, score in enumerate(similarities):
        if score < similarity_threshold:
            example = dataset[random_sample[idx]]
            for a in example["answers"]["text"]:
                negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(negative_ctxs) > min_count:
            break
    return negative_ctxs[:min_count]


def generate_dpr_training_file(args):
    embedder = SentenceTransformer(args.embedding_model)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
                                show_progress_bar=True)
    validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
                                     show_progress_bar=True)

    test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
                               show_progress_bar=True)
    exclude_answer_patterns = [re.compile("not sure what you"), re.compile("\n\n >")]
    for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
                                                         [eli5_train_set, eli5_validation_set, eli5_test_set],
                                                         [train_set, validation_set, test_set]):
        min_elements = 3
        skip_count = 0
        progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
        with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
            for idx, example in enumerate(dataset):
                negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"] if
                                    not any([p.search(a) for p in exclude_answer_patterns])]
                if not positive_context:
                    positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
                if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": positive_context[:min_elements],
                               "negative_ctxs": negative_ctxs[:min_elements],
                               "hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
                    fp.write("\n")
                else:
                    skip_count += 1
                progress_bar.update(1)

        print(f"Skipped {skip_count} questions")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
    parser.add_argument(
        "--embedding_model",
        default="all-mpnet-base-v2",
        help="Embedding model to use for question encoding and semantic search",
    )

    main_args, _ = parser.parse_known_args()
    generate_dpr_training_file(main_args)
util/create_dpr_training_from_faiss.py
ADDED
@@ -0,0 +1,144 @@
import argparse
import json

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DPRQuestionEncoder

from common import embed_questions, clean_question, articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def generate_dpr_training_file(args):
    n_negatives = 7
    min_chars_per_passage = 200

    def query_index(question, topk=(n_negatives * args.n_positives) * 2):
        question_embedding = embed_questions(question_model, question_tokenizer, [question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)

        retrieved_examples = []
        r = list(zip(wiki_passages[k] for k in columns))
        for i in range(topk):
            retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})

        return retrieved_examples

    def find_positive_and_hard_negative_ctxs(dataset_index: int, n_positive=1, device="cuda:0"):
        positive_context_list = []
        hard_negative_context_list = []
        example = dataset[dataset_index]
        question = clean_question(example['title'])
        passages = query_index(question)
        passages = [dict([(k, p[k]) for k in columns]) for p in passages]
        q_passage_pairs = [[question, f"{p['title']} {p['text']}" if args.use_title else p["text"]] for p in passages]

        features = ce_tokenizer(q_passage_pairs, padding="max_length", max_length=256, truncation=True,
                                return_tensors="pt")
        with torch.no_grad():
            passage_scores = ce_model(features["input_ids"].to(device),
                                      features["attention_mask"].to(device)).logits

        for p_idx, p in enumerate(passages):
            p["score"] = passage_scores[p_idx].item()

        # order by scores
        def score_passage(item):
            return item["score"]

        # pick the most relevant as the positive answer
        best_passage_list = sorted(passages, key=score_passage, reverse=True)
        for idx, item in enumerate(best_passage_list):
            if idx < n_positive:
                positive_context_list.append({"title": item["title"], "text": item["text"]})
            else:
                break

        # least relevant as hard_negative
        worst_passage_list = sorted(passages, key=score_passage, reverse=False)
        for idx, hard_negative in enumerate(worst_passage_list):
            if idx < n_negatives * n_positive:
                hard_negative_context_list.append({"title": hard_negative["title"], "text": hard_negative["text"]})
            else:
                break
        assert len(positive_context_list) * n_negatives == len(hard_negative_context_list)
        return positive_context_list, hard_negative_context_list

    device = ("cuda" if torch.cuda.is_available() else "cpu")

    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    _ = question_model.eval()

    ce_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2').to(device)
    ce_tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2')
    _ = ce_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 Gb GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    for dataset_name, dataset in zip(["train", "validation", "test"], [eli5_train_set,
                                                                       eli5_validation_set,
                                                                       eli5_test_set]):

        progress_bar = tqdm(range(len(dataset)), desc=f"Creating DPR formatted {dataset_name} file")
        with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
            for idx, example in enumerate(dataset):
                negative_start_idx = 0
                positive_context, hard_negative_ctxs = find_positive_and_hard_negative_ctxs(idx, args.n_positives,
                                                                                            device)
                for pc in positive_context:
                    hnc = hard_negative_ctxs[negative_start_idx:negative_start_idx + n_negatives]
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": [pc],
                               "hard_negative_ctxs": hnc}, fp)
                    fp.write("\n")
                    negative_start_idx += n_negatives
                progress_bar.update(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates DPR training file")
    parser.add_argument(
        "--use_title",
        action="store_true",
        help="If true, use title in addition to passage text for passage embedding",
    )
    parser.add_argument(
        "--n_positives",
        default=3,
        help="Number of positive samples per question",
    )
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    main_args, _ = parser.parse_known_args()
    generate_dpr_training_file(main_args)
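In this variant every question can yield up to n_positives records, and each positive context is paired with its own consecutive slice of n_negatives hard negatives, which is what the assertion len(positive_context_list) * n_negatives == len(hard_negative_context_list) guarantees. A toy sketch of that slicing, with plain strings standing in for the retrieved contexts:

    n_negatives = 7
    positive_context = ["p0", "p1", "p2"]                  # n_positives = 3
    hard_negative_ctxs = [f"hn{i}" for i in range(21)]      # 3 * 7 hard negatives

    negative_start_idx = 0
    for pc in positive_context:
        hnc = hard_negative_ctxs[negative_start_idx:negative_start_idx + n_negatives]
        print(pc, hnc[0], hnc[-1])   # p0 hn0 hn6, then p1 hn7 hn13, then p2 hn14 hn20
        negative_start_idx += n_negatives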
util/create_faiss_index.py
ADDED
@@ -0,0 +1,67 @@
import argparse
import os

import faiss
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DPRContextEncoder

from common import articles_to_paragraphs, embed_passages


def create_faiss(args):
    dims = 128
    min_chars_per_passage = 200
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
    ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
    _ = ctx_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                              'wikidata_info', 'history']

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 Gb GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    if not os.path.isfile(args.index_file_name):
        def embed_passages_for_retrieval(examples):
            return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)

        paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
                                                              batched=True, batch_size=512,
                                                              cache_file_name="../data/kilt_embedded.arrow",
                                                              desc="Creating faiss index")

        paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
        paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
    else:
        print(f"Faiss index already exists {args.index_file_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")

    parser.add_argument(
        "--ctx_encoder_name",
        default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
        help="Encoding model to use for passage encoding",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia.faiss",
        help="Faiss index file with passage embeddings",
    )

    main_args, _ = parser.parse_known_args()
    create_faiss(main_args)
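Once create_faiss has written the index, the other scripts in this commit reload it onto a GPU and attach it to the same paragraph dataset before querying. A minimal load-and-query sketch, assuming a GPU is available and reusing the helpers from common.py together with the script's default index file name:

    import torch
    from datasets import load_dataset
    from transformers import AutoTokenizer, DPRQuestionEncoder

    from common import articles_to_paragraphs, embed_questions, kilt_wikipedia_columns

    question_tokenizer = AutoTokenizer.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-base")
    question_model = DPRQuestionEncoder.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-base").to("cuda:0")
    _ = question_model.eval()

    # rebuild the same paragraph dataset the index was created from
    paragraphs = load_dataset("kilt_wikipedia", split="full").map(
        articles_to_paragraphs, batched=True, batch_size=512,
        remove_columns=kilt_wikipedia_columns)
    paragraphs = paragraphs.filter(lambda x: (x["end_character"] - x["start_character"]) > 200)
    paragraphs.load_faiss_index("embeddings", "../data/kilt_dpr_wikipedia.faiss", device=0)

    q_emb = embed_questions(question_model, question_tokenizer, ["what causes contrails behind jets?"])
    scores, nearest = paragraphs.get_nearest_examples("embeddings", q_emb, k=5)
    print(nearest["title"])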
util/eval_generate.py
ADDED
@@ -0,0 +1,140 @@
import argparse
import json
import os

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder

from common import articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def eval_generate(args):
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    _ = eli5_model.eval()

    min_snippet_length = 20
    topk = 21
    min_chars_per_passage = 200
    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=256,
                                                   cache_file_name=f"./data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)

    def embed_questions_for_retrieval(questions):
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)

        retrieved_examples = []
        r = list(zip(wiki_passages[k] for k in columns))
        for i in range(topk):
            retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
        return retrieved_examples

    def create_kilt_datapoint(q_id, query, answer, res_list):
        # make a KILT data point
        # see https://github.com/facebookresearch/KILT#kilt-data-format

        provenance = [{
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]

        return {"id": q_id,
                "input": query,
                "output": output,  # each element is an answer or provenance (can have multiple of each)
                "meta": None  # dataset/task specific
                }

    kilt_output = []
    with open(args.kilt_input_file, "r") as f:
        kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
        progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
        for idx, item in enumerate(kilt_items):
            query = item["input"]
            res_list = query_index(query)

            res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
            documents = [res["text"] for res in res_list]
            conditioned_doc = "<P> " + " <P> ".join([d for d in documents])

            query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

            model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
            generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
                                                            attention_mask=model_input["attention_mask"].to(device),
                                                            min_length=50,
                                                            max_length=250,
                                                            do_sample=False,
                                                            early_stopping=True,
                                                            num_beams=8,
                                                            temperature=1.0,
                                                            top_k=None,
                                                            top_p=None,
                                                            no_repeat_ngram_size=3,
                                                            num_return_sequences=1)
            answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                 clean_up_tokenization_spaces=True)

            kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
            kilt_output.append(kilt_example)
            progress_bar.update(1)

    with open(args.kilt_output_file, "w") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()

    assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
    eval_generate(args)
util/kilt_create_dpr_support_docs.py
ADDED
@@ -0,0 +1,109 @@
import argparse
import json
import os

import faiss
import torch
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder

from common import articles_to_paragraphs, embed_questions, embed_passages, create_kilt_datapoint, \
    kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def generate_support_docs(args):
    dims = 128
    min_chars_per_passage = 200
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    lfqa = load_dataset("vblagoje/lfqa")

    ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
    ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
    _ = ctx_model.eval()

    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name=f"../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 Gb GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    def query_index(question, topk=7):
        topk = topk * 3  # grab 3x results and filter for word count
        question_embedding = embed_questions(question_model, question_tokenizer, [question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)

        retrieved_examples = []
        r = list(zip(wiki_passages[k] for k in columns))
        for i in range(topk):
            retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})

        return retrieved_examples

    def create_support_doc(dataset: Dataset, output_filename: str):
        progress_bar = tqdm(range(len(dataset)), desc="Creating supporting docs")

        with open(output_filename, "w") as fp:
            for example in dataset:
                wiki_passages = query_index(example["title"])
                kilt_dp = create_kilt_datapoint(example, columns, wiki_passages)
                json.dump(kilt_dp, fp)
                fp.write("\n")
                progress_bar.update(1)

    if not os.path.isfile(args.index_file_name):
        def embed_passages_for_retrieval(examples):
            return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)

        paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
                                                              batched=True, batch_size=512,
                                                              cache_file_name=args.encoded_kilt_file_name,
                                                              desc="Creating faiss index")

        paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
        paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)

    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
    create_support_doc(lfqa["train"], "lfqa_dpr_train_precomputed_dense_docs.json")
    create_support_doc(lfqa["validation"], "lfqa_dpr_validation_precomputed_dense_docs.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates support docs for seq2seq model training")
    parser.add_argument(
        "--ctx_encoder_name",
        default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
        help="Context encoder to use",
    )
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    parser.add_argument(
        "--encoded_kilt_file_name",
        default="../data/kilt_embedded.arrow",
        help="Encoded KILT file name",
    )

    main_args, _ = parser.parse_known_args()
    generate_support_docs(main_args)
util/query_smoke_test.py
ADDED
@@ -0,0 +1,81 @@
import torch
from transformers import AutoTokenizer, AutoModel

from datasets import load_dataset


def main():
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained('vblagoje/retribert-base-uncased')
    model = AutoModel.from_pretrained('vblagoje/retribert-base-uncased').to(device)
    _ = model.eval()

    index_file_name = "./data/kilt_wikipedia.faiss"
    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
               'wikidata_info', 'history']

    min_snippet_length = 20
    topk = 21

    def articles_to_paragraphs(examples):
        ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
        for bidx, example in enumerate(examples["text"]):
            last_section = ""
            for idx, p in enumerate(example["paragraph"]):
                if "Section::::" in p:
                    last_section = p
                ids.append(examples["wikipedia_id"][bidx])
                titles.append(examples["wikipedia_title"][bidx])
                sections.append(last_section)
                texts.append(p)
                start_ps.append(idx)
                end_ps.append(idx)
                start_cs.append(0)
                end_cs.append(len(p))

        return {"wikipedia_id": ids, "title": titles,
                "section": sections, "text": texts,
                "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
                "start_character": start_cs,
                "end_character": end_cs
                }

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=columns,
                                                   batch_size=256, cache_file_name=f"./wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 250)
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)

    def embed_questions_for_retrieval(questions):
        query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = model.embed_questions(query["input_ids"].to(device),
                                           query["attention_mask"].to(device)).cpu().type(torch.float)
        return q_reps.numpy()

    def query_index(question):
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
        columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
                   'start_character', 'end_character']
        retrieved_examples = []
        r = list(zip(wiki_passages[k] for k in columns))
        for i in range(topk):
            retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
        return retrieved_examples

    questions = ["What causes the contrails (cirrus aviaticus) behind jets at high altitude? ",
                 "Why does water heated to room temperature feel colder than the air around it?"]
    res_list = query_index(questions[0])
    res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
    for res in res_list:
        print("\n")
        print(res)


main()
wikipedia_answer.png
ADDED
wikipedia_context.png
ADDED