king007 and danijelpetkovic committed on
Commit 039aebb · 0 Parent(s):

Duplicate from deepset/wikipedia-assistant

Co-authored-by: Danijel Petkovic <danijelpetkovic@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,154 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105
+ __pypackages__/
106
+
107
+ # Celery stuff
108
+ celerybeat-schedule
109
+ celerybeat.pid
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
140
+
141
+ # pytype static type analyzer
142
+ .pytype/
143
+
144
+ # Cython debug symbols
145
+ cython_debug/
146
+
147
+ # PyCharm
148
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
149
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
150
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
151
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152
+ .idea/
153
+ out.flac
154
+ out.mp3
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
1
+ [logger]
2
+ level = "debug"
README.md ADDED
@@ -0,0 +1,39 @@
1
+ ---
2
+ title: Wikipedia Assistant
3
+ emoji: 🌖
4
+ colorFrom: green
5
+ colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.9.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: deepset/wikipedia-assistant
11
+ ---
12
+
13
+ # Configuration
14
+
15
+ `title`: _string_
16
+ Display title for the Space
17
+
18
+ `emoji`: _string_
19
+ Space emoji (emoji-only character allowed)
20
+
21
+ `colorFrom`: _string_
22
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
+
24
+ `colorTo`: _string_
25
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
+
27
+ `sdk`: `streamlit`
28
+ Can be either `gradio` or `streamlit`
29
+
30
+ `sdk_version`: `1.9.0`
31
+ Only applicable for `streamlit` SDK.
32
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
33
+
34
+ `app_file`: _string_
35
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
36
+ Path is relative to the root of the repository.
37
+
38
+ `pinned`: _boolean_
39
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,49 @@
1
+ import sys
2
+ import logging
3
+ import streamlit as st
4
+ from multipage import MultiPage
5
+ from pages import ask, settings, info
6
+
7
+
8
+ logging.basicConfig(
9
+ level=logging.DEBUG,
10
+ format="%(levelname)s %(asctime)s %(name)s:%(message)s",
11
+ handlers=[logging.StreamHandler(sys.stdout)],
12
+ force=True,
13
+ )
14
+
15
+
16
+ def init_session_key_value(key, value):
17
+ if key not in st.session_state:
18
+ st.session_state[key] = value
19
+
20
+
21
+ lfqa_api = "HuggingFace" if "api_lfqa_selector" not in st.secrets else st.secrets["api_lfqa_selector"]
22
+ session_values = {"api_lfqa_selector": lfqa_api,
23
+ "tts": "Google",
24
+ "min_length": 64,
25
+ "max_length": 256,
26
+ "do_sample": False,
27
+ "early_stopping": True,
28
+ "num_beams": 8,
29
+ "temperature": 1.0,
30
+ "top_k": None,
31
+ "top_p": None,
32
+ "no_repeat_ngram_size": 3,
33
+ "num_return_sequences": 1}
34
+
35
+ for k, v in session_values.items():
36
+ init_session_key_value(k, v)
37
+
38
+ app = MultiPage()
39
+ st.set_page_config(
40
+ page_title="Wikipedia Assistant",
41
+ initial_sidebar_state="expanded",
42
+ )
43
+ # Add all your application pages here
44
+ app.add_page("Home", "house", ask.app)
45
+ app.add_page("Settings", "gear", settings.app)
46
+ app.add_page("Info", "info", info.app)
47
+
48
+ # The main app
49
+ app.run()
context_server/Dockerfile ADDED
@@ -0,0 +1,23 @@
1
+ FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
2
+ #set up environment
3
+ RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
4
+ RUN apt-get install -y unzip
5
+ RUN apt-get -y install python3
6
+ RUN apt-get -y install python3-pip
7
+
8
+ WORKDIR /code
9
+
10
+ ENV HF_HOME=/code/cache
11
+
12
+ COPY ./requirements.txt /code/requirements.txt
13
+
14
+ RUN pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
15
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
16
+
17
+ COPY ./main.py /code/app/main.py
18
+
19
+ COPY ./data/kilt_wiki_prepared/ /code/data/kilt_wiki_prepared
20
+
21
+ COPY ./data/kilt_wikipedia.faiss /code/data/kilt_wikipedia.faiss
22
+
23
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
context_server/__init__.py ADDED
File without changes
context_server/main.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ from fastapi import FastAPI, Depends, status
3
+ from fastapi.responses import PlainTextResponse
4
+ from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder
5
+
6
+ from datasets import load_from_disk
7
+ import time
8
+ from typing import Dict
9
+
10
+ import jwt
11
+ from decouple import config
12
+ from fastapi import Request, HTTPException
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+
15
+ JWT_SECRET = config("secret")
16
+ JWT_ALGORITHM = config("algorithm")
17
+
18
+ app = FastAPI()
19
+ app.ready = False
20
+ columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
21
+ 'wikidata_info', 'history']
22
+
23
+ min_snippet_length = 20
24
+ topk = 21
25
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
26
+ model = DPRQuestionEncoder.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki").to(device)
27
+ tokenizer = AutoTokenizer.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki")
28
+ _ = model.eval()
29
+
30
+ index_file_name = "./data/kilt_wikipedia.faiss"
31
+
32
+ kilt_wikipedia_paragraphs = load_from_disk("./data/kilt_wiki_prepared")
33
+ # use paragraphs that are not simple fragments or very short sentences
34
+ kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 200)
35
+
36
+
37
+ class JWTBearer(HTTPBearer):
38
+ def __init__(self, auto_error: bool = True):
39
+ super(JWTBearer, self).__init__(auto_error=auto_error)
40
+
41
+ async def __call__(self, request: Request):
42
+ credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
43
+ if credentials:
44
+ if not credentials.scheme == "Bearer":
45
+ raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
46
+ if not self.verify_jwt(credentials.credentials):
47
+ raise HTTPException(status_code=403, detail="Invalid token or expired token.")
48
+ return credentials.credentials
49
+ else:
50
+ raise HTTPException(status_code=403, detail="Invalid authorization code.")
51
+
52
+ def verify_jwt(self, jwtoken: str) -> bool:
53
+ isTokenValid: bool = False
54
+
55
+ try:
56
+ payload = decodeJWT(jwtoken)
57
+ except:
58
+ payload = None
59
+ if payload:
60
+ isTokenValid = True
61
+ return isTokenValid
62
+
63
+
64
+ def token_response(token: str):
65
+ return {
66
+ "access_token": token
67
+ }
68
+
69
+
70
+ def signJWT(user_id: str) -> Dict[str, str]:
71
+ payload = {
72
+ "user_id": user_id,
73
+ "expires": time.time() + 6000
74
+ }
75
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
76
+
77
+ return token_response(token)
78
+
79
+
80
+ def decodeJWT(token: str) -> dict:
81
+ try:
82
+ decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
83
+ return decoded_token if decoded_token["expires"] >= time.time() else None
84
+ except:
85
+ return {}
86
+
87
+
88
+ def embed_questions_for_retrieval(questions):
89
+ query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
90
+ with torch.no_grad():
91
+ q_reps = model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output
92
+ return q_reps.cpu().numpy()
93
+
94
+ def query_index(question):
95
+ question_embedding = embed_questions_for_retrieval([question])
96
+ scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
97
+ columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
98
+ 'start_character', 'end_character']
99
+ retrieved_examples = []
100
+ r = list(zip(wiki_passages[k] for k in columns))
101
+ for i in range(topk):
102
+ retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
103
+ return retrieved_examples
104
+
105
+
106
+ @app.on_event("startup")
107
+ def startup():
108
+ kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
109
+ app.ready = True
110
+
111
+
112
+ @app.get("/healthz")
113
+ def healthz():
114
+ if app.ready:
115
+ return PlainTextResponse("ok")
116
+ return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
117
+
118
+
119
+ @app.get("/find_context", dependencies=[Depends(JWTBearer())])
120
+ def find_context(question: str = None):
121
+ return [res for res in query_index(question) if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
122
+
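The /find_context route above is guarded by the JWTBearer dependency, so a caller has to present a token signed with the same secret and algorithm the server reads through python-decouple. A minimal client sketch, assuming a hypothetical local deployment URL and placeholder secret values (none of these are part of the repository):

    # Hedged client sketch for the context server above; URL and secrets are placeholders.
    import time
    import jwt        # PyJWT, pinned to 1.7.1 in requirements.txt
    import requests

    JWT_SECRET = "change-me"      # assumption: must match the server's "secret" value
    JWT_ALGORITHM = "HS256"       # assumption: must match the server's "algorithm" value
    CONTEXT_URL = "http://localhost:8080/find_context"   # hypothetical address

    token = jwt.encode({"user_id": "demo", "expires": time.time() + 6000},
                       JWT_SECRET, algorithm=JWT_ALGORITHM)
    if isinstance(token, bytes):  # PyJWT 1.x returns bytes, 2.x returns str
        token = token.decode("utf-8")

    response = requests.get(CONTEXT_URL,
                            params={"question": "Why do airplanes leave contrails in the sky?"},
                            headers={"Authorization": f"Bearer {token}"})
    response.raise_for_status()
    for passage in response.json():
        print(passage["title"], "-", passage["text"][:80])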
context_server/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ datasets
2
+ transformers
3
+ fastapi
4
+ faiss-gpu
5
+ uvicorn[standard]
6
+ PyJWT==1.7.1
7
+ python-decouple==3.3
lfqa.png ADDED
lfqa_server/Dockerfile ADDED
@@ -0,0 +1,19 @@
1
+ FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
2
+ #set up environment
3
+ RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
4
+ RUN apt-get install -y unzip
5
+ RUN apt-get -y install python3
6
+ RUN apt-get -y install python3-pip
7
+
8
+ WORKDIR /code
9
+
10
+ ENV HF_HOME=/code/cache
11
+
12
+ COPY ./requirements.txt /code/requirements.txt
13
+
14
+ RUN pip3 install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
15
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
16
+
17
+ COPY ./main.py /code/app/main.py
18
+
19
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
lfqa_server/__init__.py ADDED
File without changes
lfqa_server/main.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ from fastapi import FastAPI, Depends, status
3
+ from fastapi.responses import PlainTextResponse
4
+ from pydantic import BaseModel
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ import time
8
+ from typing import Dict, List, Optional
9
+
10
+ import jwt
11
+ from decouple import config
12
+ from fastapi import Request, HTTPException
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+
15
+ JWT_SECRET = config("secret")
16
+ JWT_ALGORITHM = config("algorithm")
17
+
18
+ app = FastAPI()
19
+ app.ready = False
20
+
21
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
22
+ tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
23
+ model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
24
+ _ = model.eval()
25
+
26
+
27
+ class JWTBearer(HTTPBearer):
28
+ def __init__(self, auto_error: bool = True):
29
+ super(JWTBearer, self).__init__(auto_error=auto_error)
30
+
31
+ async def __call__(self, request: Request):
32
+ credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
33
+ if credentials:
34
+ if not credentials.scheme == "Bearer":
35
+ raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
36
+ if not self.verify_jwt(credentials.credentials):
37
+ raise HTTPException(status_code=403, detail="Invalid token or expired token.")
38
+ return credentials.credentials
39
+ else:
40
+ raise HTTPException(status_code=403, detail="Invalid authorization code.")
41
+
42
+ def verify_jwt(self, jwtoken: str) -> bool:
43
+ isTokenValid: bool = False
44
+
45
+ try:
46
+ payload = decodeJWT(jwtoken)
47
+ except:
48
+ payload = None
49
+ if payload:
50
+ isTokenValid = True
51
+ return isTokenValid
52
+
53
+
54
+ def token_response(token: str):
55
+ return {
56
+ "access_token": token
57
+ }
58
+
59
+
60
+ def signJWT(user_id: str) -> Dict[str, str]:
61
+ payload = {
62
+ "user_id": user_id,
63
+ "expires": time.time() + 6000
64
+ }
65
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
66
+
67
+ return token_response(token)
68
+
69
+
70
+ def decodeJWT(token: str) -> dict:
71
+ try:
72
+ decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
73
+ return decoded_token if decoded_token["expires"] >= time.time() else None
74
+ except:
75
+ return {}
76
+
77
+
78
+ class LFQAParameters(BaseModel):
79
+ min_length: int = 50
80
+ max_length: int = 250
81
+ do_sample: bool = False
82
+ early_stopping: bool = True
83
+ num_beams: int = 8
84
+ temperature: float = 1.0
85
+ top_k: float = None
86
+ top_p: float = None
87
+ no_repeat_ngram_size: int = 3
88
+ num_return_sequences: int = 1
89
+
90
+
91
+ class InferencePayload(BaseModel):
92
+ model_input: str
93
+ parameters: Optional[LFQAParameters] = LFQAParameters()
94
+
95
+
96
+ @app.on_event("startup")
97
+ def startup():
98
+ app.ready = True
99
+
100
+
101
+ @app.get("/healthz")
102
+ def healthz():
103
+ if app.ready:
104
+ return PlainTextResponse("ok")
105
+ return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
106
+
107
+
108
+ @app.post("/generate/", dependencies=[Depends(JWTBearer())])
109
+ def generate(context: InferencePayload):
110
+
111
+ model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
112
+ param = context.parameters
113
+ generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
114
+ attention_mask=model_input["attention_mask"].to(device),
115
+ min_length=param.min_length,
116
+ max_length=param.max_length,
117
+ do_sample=param.do_sample,
118
+ early_stopping=param.early_stopping,
119
+ num_beams=param.num_beams,
120
+ temperature=param.temperature,
121
+ top_k=param.top_k,
122
+ top_p=param.top_p,
123
+ no_repeat_ngram_size=param.no_repeat_ngram_size,
124
+ num_return_sequences=param.num_return_sequences)
125
+ answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
126
+ clean_up_tokenization_spaces=True)
127
+ results = []
128
+ for answer in answers:
129
+ results.append({"generated_text": answer})
130
+ return results
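The /generate/ route above accepts an InferencePayload body (model_input plus optional LFQAParameters overrides) and returns a list of {"generated_text": ...} objects. A hedged request sketch, again with a hypothetical URL and a placeholder bearer token obtained as in the previous sketch:

    # Hedged client sketch for the LFQA generation server above.
    import requests

    LFQA_URL = "http://localhost:8080/generate/"       # hypothetical address
    token = "<JWT signed with the server's secret>"    # placeholder

    payload = {
        # model_input follows the conditioning format built in pages/ask.py:
        # "question: <question> context: <P> passage one <P> passage two ..."
        "model_input": "question: Why is the sky blue? context: <P> Rayleigh scattering ...",
        "parameters": {"min_length": 64, "max_length": 256, "num_beams": 8},
    }
    response = requests.post(LFQA_URL, json=payload,
                             headers={"Authorization": f"Bearer {token}"})
    response.raise_for_status()
    print(response.json()[0]["generated_text"])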
lfqa_server/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ datasets
2
+ transformers
3
+ fastapi
4
+ faiss-gpu
5
+ uvicorn[standard]
6
+ PyJWT==1.7.1
7
+ python-decouple==3.3
multipage.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ This file is the framework for generating multiple Streamlit applications
3
+ through an object oriented framework.
4
+ """
5
+
6
+ # Import necessary libraries
7
+ import streamlit as st
8
+ from streamlit_option_menu import option_menu
9
+
10
+
11
+ # Define the multipage class to manage the multiple apps in our program
12
+ class MultiPage:
13
+ """Framework for combining multiple streamlit applications."""
14
+
15
+ def __init__(self) -> None:
16
+ """Constructor class to generate a list which will store all our applications as an instance variable."""
17
+ self.pages = []
18
+
19
+ def add_page(self, title, icon, func) -> None:
20
+ """Class Method to Add pages to the project
21
+
22
+ Args:
23
+ title (str): The title of the page being added to the list of apps
24
+
25
+ func: Python function to render this page in Streamlit
26
+ """
27
+
28
+ self.pages.append(
29
+ {
30
+ "title": title,
31
+ "icon": icon,
32
+ "function": func
33
+ }
34
+ )
35
+
36
+ def run(self):
37
+ # Dropdown to select the page to run
38
+ st.markdown("""
39
+ <style>
40
+ section[data-testid="stSidebar"] > div:first-of-type {
41
+ background-color: var(--secondary-background-color);
42
+ background: var(--secondary-background-color);
43
+ width: 250px;
44
+ padding: 4rem 0;
45
+ box-shadow: -2rem 0px 2rem 2rem rgba(0,0,0,0.16);
46
+ }
47
+ section[aria-expanded="true"] > div:nth-of-type(2) {
48
+ display: none;
49
+ }
50
+ .main > div:first-of-type {
51
+ padding: 1rem 0;
52
+ }
53
+ </style>
54
+ """, unsafe_allow_html=True)
55
+
56
+ with st.sidebar:
57
+ selected = option_menu(None, [page["title"] for page in self.pages],
58
+ icons=[page["icon"] for page in self.pages],
59
+ menu_icon="cast", default_index=0)
60
+
61
+ # Run the selected page
62
+ for index, item in enumerate(self.pages):
63
+ if item["title"] == selected:
64
+ self.pages[index]["function"]()
65
+ break
pages/ask.py ADDED
@@ -0,0 +1,379 @@
1
+ import logging
2
+ import colorsys
3
+ import json
4
+ import re
5
+ import time
6
+
7
+ import nltk
8
+ import numpy as np
9
+ from nltk import tokenize
10
+
11
+ nltk.download('punkt')
12
+ from google.oauth2 import service_account
13
+ from google.cloud import texttospeech
14
+
15
+ from typing import Dict, Optional, List
16
+
17
+ import jwt
18
+ import requests
19
+ import streamlit as st
20
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
21
+
22
+ JWT_SECRET = st.secrets["api_secret"]
23
+ JWT_ALGORITHM = st.secrets["api_algorithm"]
24
+ INFERENCE_TOKEN = st.secrets["api_inference"]
25
+ CONTEXT_API_URL = st.secrets["api_context"]
26
+ LFQA_API_URL = st.secrets["api_lfqa"]
27
+
28
+ headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
29
+ API_URL = "https://api-inference.huggingface.co/models/askainet/bart_lfqa"
30
+ API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def api_inference_lfqa(model_input: str):
36
+ payload = {
37
+ "inputs": model_input,
38
+ "parameters": {
39
+ "truncation": "longest_first",
40
+ "min_length": st.session_state["min_length"],
41
+ "max_length": st.session_state["max_length"],
42
+ "do_sample": st.session_state["do_sample"],
43
+ "early_stopping": st.session_state["early_stopping"],
44
+ "num_beams": st.session_state["num_beams"],
45
+ "temperature": st.session_state["temperature"],
46
+ "top_k": None,
47
+ "top_p": None,
48
+ "no_repeat_ngram_size": 3,
49
+ "num_return_sequences": 1
50
+ },
51
+ "options": {
52
+ "wait_for_model": True
53
+ }
54
+ }
55
+ data = json.dumps(payload)
56
+ logger.debug(data)
57
+ response = requests.request("POST", API_URL, headers=headers, data=data)
58
+ return json.loads(response.content.decode("utf-8"))
59
+
60
+
61
+ def inference_lfqa(model_input: str, header: dict):
62
+ payload = {
63
+ "model_input": model_input,
64
+ "parameters": {
65
+ "min_length": st.session_state["min_length"],
66
+ "max_length": st.session_state["max_length"],
67
+ "do_sample": st.session_state["do_sample"],
68
+ "early_stopping": st.session_state["early_stopping"],
69
+ "num_beams": st.session_state["num_beams"],
70
+ "temperature": st.session_state["temperature"],
71
+ "top_k": None,
72
+ "top_p": None,
73
+ "no_repeat_ngram_size": 3,
74
+ "num_return_sequences": 1
75
+ }
76
+ }
77
+ data = json.dumps(payload)
78
+ try:
79
+ response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
80
+ if response.status_code == 200:
81
+ json_response = response.content.decode("utf-8")
82
+ result = json.loads(json_response)
83
+ else:
84
+ result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
85
+ except requests.exceptions.RequestException as e:
86
+ result = {"error": e}
87
+ return result
88
+
89
+
90
+ def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
91
+ if "HuggingFace" == service_backend:
92
+ inference_response = api_inference_lfqa(model_input)
93
+ else:
94
+ inference_response = inference_lfqa(model_input, header)
95
+ return inference_response
96
+
97
+
98
+ @st.cache(allow_output_mutation=True, show_spinner=False)
99
+ def hf_tts(text: str):
100
+ payload = {
101
+ "inputs": text,
102
+ "parameters": {
103
+ "vocoder_tag": "str_or_none(none)",
104
+ "threshold": 0.5,
105
+ "minlenratio": 0.0,
106
+ "maxlenratio": 10.0,
107
+ "use_att_constraint": False,
108
+ "backward_window": 1,
109
+ "forward_window": 3,
110
+ "speed_control_alpha": 1.0,
111
+ "noise_scale": 0.333,
112
+ "noise_scale_dur": 0.333
113
+ },
114
+ "options": {
115
+ "wait_for_model": True
116
+ }
117
+ }
118
+ data = json.dumps(payload)
119
+ response = requests.request("POST", API_URL_TTS, headers=headers, data=data)
120
+ return response.content
121
+
122
+
123
+ @st.cache(allow_output_mutation=True, show_spinner=False)
124
+ def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
125
+ config = {
126
+ "private_key_id": private_key_id,
127
+ "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
128
+ "client_email": client_email,
129
+ "token_uri": "https://oauth2.googleapis.com/token",
130
+ }
131
+ credentials = service_account.Credentials.from_service_account_info(config)
132
+ client = texttospeech.TextToSpeechClient(credentials=credentials)
133
+
134
+ synthesis_input = texttospeech.SynthesisInput(text=text)
135
+
136
+ # Build the voice request, select the language code ("en-US") and the ssml
137
+ # voice gender ("neutral")
138
+ voice = texttospeech.VoiceSelectionParams(language_code="en-US",
139
+ ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL)
140
+
141
+ # Select the type of audio file you want returned
142
+ audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)
143
+
144
+ # Perform the text-to-speech request on the text input with the selected
145
+ # voice parameters and audio file type
146
+ response = client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
147
+ return response
148
+
149
+
150
+ def request_context_passages(question, header):
151
+ try:
152
+ response = requests.request("GET", CONTEXT_API_URL + question, headers=header)
153
+ if response.status_code == 200:
154
+ json_response = response.content.decode("utf-8")
155
+ result = json.loads(json_response)
156
+ else:
157
+ result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
158
+ except requests.exceptions.RequestException as e:
159
+ result = {"error": e}
160
+
161
+ return result
162
+
163
+
164
+ @st.cache(allow_output_mutation=True, show_spinner=False)
165
+ def get_sentence_transformer():
166
+ return SentenceTransformer('all-MiniLM-L6-v2')
167
+
168
+
169
+ @st.cache(allow_output_mutation=True, show_spinner=False)
170
+ def get_sentence_transformer_encoding(sentences):
171
+ model = get_sentence_transformer()
172
+ return model.encode([sentence for sentence in sentences], convert_to_tensor=True)
173
+
174
+
175
+ def sign_jwt() -> Dict[str, str]:
176
+ payload = {
177
+ "expires": time.time() + 6000
178
+ }
179
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
180
+ return token
181
+
182
+
183
+ def extract_sentences_from_passages(passages):
184
+ sentences = []
185
+ for idx, node in enumerate(passages):
186
+ sentences.extend(tokenize.sent_tokenize(node["text"]))
187
+ return sentences
188
+
189
+
190
+ def similarity_color_picker(similarity: float):
191
+ value = int(similarity * 75)
192
+ rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
193
+ return [round(255 * x) for x in rgb]
194
+
195
+
196
+ def rgb_to_hex(rgb):
197
+ return '%02x%02x%02x' % tuple(rgb)
198
+
199
+
200
+ def similiarity_to_hex(similarity: float):
201
+ return rgb_to_hex(similarity_color_picker(similarity))
202
+
203
+
204
+ def rerank(question: str, passages: List[str], include_rank: int = 4) -> List[str]:
205
+ ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
206
+ question_passage_combinations = [[question, p["text"]] for p in passages]
207
+
208
+ # Compute the similarity scores for these combinations
209
+ similarity_scores = ce.predict(question_passage_combinations)
210
+
211
+ # Sort the scores in decreasing order
212
+ sim_ranking_idx = np.flip(np.argsort(similarity_scores))
213
+ return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]
214
+
215
+
216
+ def answer_to_context_similarity(generated_answer, context_passages, topk=3):
217
+ context_sentences = extract_sentences_from_passages(context_passages)
218
+ context_sentences_e = get_sentence_transformer_encoding(context_sentences)
219
+ answer_sentences = tokenize.sent_tokenize(generated_answer)
220
+ answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
221
+ search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
222
+ result = []
223
+ for idx, r in enumerate(search_result):
224
+ context = []
225
+ for idx_c in range(topk):
226
+ context.append({"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]})
227
+ result.append({"answer": answer_sentences[idx], "context": context})
228
+ return result
229
+
230
+
231
+ def post_process_answer(generated_answer):
232
+ result = generated_answer
233
+ # detect sentence boundaries regex pattern
234
+ regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
235
+ answer_sentences = tokenize.sent_tokenize(generated_answer)
236
+ # do we have truncated last sentence?
237
+ if len(answer_sentences) > len(re.findall(regex, generated_answer)):
238
+ drop_last_sentence = " ".join(s for s in answer_sentences[:-1])
239
+ result = drop_last_sentence
240
+ return result.strip()
241
+
242
+
243
+ def format_score(value: float, precision=2):
244
+ return f"{value:.{precision}f}"
245
+
246
+
247
+ @st.cache(allow_output_mutation=True, show_spinner=False)
248
+ def get_answer(question: str):
249
+ if not question:
250
+ return {}
251
+
252
+ resp: Dict[str, str] = {}
253
+ if question and len(question.split()) > 3:
254
+ header = {"Authorization": f"Bearer {sign_jwt()}"}
255
+ context_passages = request_context_passages(question, header)
256
+ if "error" in context_passages:
257
+ resp = context_passages
258
+ else:
259
+ context_passages = rerank(question, context_passages)
260
+ conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
261
+ model_input = f'question: {question} context: {conditioned_context}'
262
+
263
+ inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
264
+ if "error" in inference_response:
265
+ resp = inference_response
266
+ else:
267
+ resp["context_passages"] = context_passages
268
+ resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
269
+ else:
270
+ resp = {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}
271
+ return resp
272
+
273
+
274
+ def app():
275
+ with open('style.css') as f:
276
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
277
+ footer = """
278
+ <div class="footer-custom">
279
+ Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
280
+ LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
281
+ Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
282
+ <a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
283
+ </div>
284
+ """
285
+ st.markdown(footer, unsafe_allow_html=True)
286
+
287
+ st.title('Wikipedia Assistant')
288
+
289
+ question = st.text_input(
290
+ label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"')
291
+
292
+ spinner = st.empty()
293
+ if question != "":
294
+ spinner.markdown(
295
+ f"""
296
+ <div class="loader-wrapper">
297
+ <div class="loader">
298
+ </div>
299
+ <p>Generating answer for: <b>{question}</b></p>
300
+ </div>
301
+ <label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
302
+ """,
303
+ unsafe_allow_html=True,
304
+ )
305
+
306
+ question_response = get_answer(question)
307
+ if question_response:
308
+ if "error" in question_response:
309
+ st.warning(question_response["error"])
310
+ else:
311
+ spinner.markdown(f"")
312
+ generated_answer = question_response["answer"]
313
+ context_passages = question_response["context_passages"]
314
+ sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
315
+ sentences = "<div class='sentence-wrapper'>"
316
+ for item in sentence_similarity:
317
+ sentences += '<span>'
318
+ score = item["context"][0]["score"]
319
+ support_sentence = item["context"][0]["source"]
320
+ sentences += "".join([
321
+ f' {item["answer"]}',
322
+ f'<span style="background-color: #{similiarity_to_hex(score)}" class="tooltip">',
323
+ f'{format_score(score, precision=1)}',
324
+ f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>'
325
+ ])
326
+ sentences += '</span>'
327
+ sentences += '</span>'
328
+ st.markdown(sentences, unsafe_allow_html=True)
329
+
330
+ with st.spinner("Generating audio..."):
331
+ if st.session_state["tts"] == "HuggingFace":
332
+ audio_file = hf_tts(generated_answer)
333
+ with open("out.flac", "wb") as f:
334
+ f.write(audio_file)
335
+ else:
336
+ audio_file = google_tts(generated_answer, st.secrets["private_key_id"],
337
+ st.secrets["private_key"], st.secrets["client_email"])
338
+ with open("out.mp3", "wb") as f:
339
+ f.write(audio_file.audio_content)
340
+
341
+ audio_file = "out.flac" if st.session_state["tts"] == "HuggingFace" else "out.mp3"
342
+ st.audio(audio_file)
343
+
344
+ st.markdown("""<hr></hr>""", unsafe_allow_html=True)
345
+
346
+ model = get_sentence_transformer()
347
+
348
+ col1, col2 = st.columns(2)
349
+
350
+ with col1:
351
+ st.subheader("Context")
352
+ with col2:
353
+ selection = st.selectbox(
354
+ label="",
355
+ options=('Paragraphs', 'Sentences', 'Answer Similarity'),
356
+ help="Context represents Wikipedia passages used to generate the answer")
357
+ question_e = model.encode(question, convert_to_tensor=True)
358
+ if selection == "Paragraphs":
359
+ sentences = extract_sentences_from_passages(context_passages)
360
+ context_e = get_sentence_transformer_encoding(sentences)
361
+ scores = util.cos_sim(question_e.repeat(context_e.shape[0], 1), context_e)
362
+ similarity_scores = scores[0].squeeze().tolist()
363
+ for idx, node in enumerate(context_passages):
364
+ node["answer_similarity"] = "{0:.2f}".format(similarity_scores[idx])
365
+ context_passages = sorted(context_passages, key=lambda x: x["answer_similarity"], reverse=True)
366
+ st.json(context_passages)
367
+ elif selection == "Sentences":
368
+ sentences = extract_sentences_from_passages(context_passages)
369
+ sentences_e = get_sentence_transformer_encoding(sentences)
370
+ scores = util.cos_sim(question_e.repeat(sentences_e.shape[0], 1), sentences_e)
371
+ sentence_similarity_scores = scores[0].squeeze().tolist()
372
+ result = []
373
+ for idx, sentence in enumerate(sentences):
374
+ result.append(
375
+ {"text": sentence, "answer_similarity": "{0:.2f}".format(sentence_similarity_scores[idx])})
376
+ context_sentences = json.dumps(sorted(result, key=lambda x: x["answer_similarity"], reverse=True))
377
+ st.json(context_sentences)
378
+ else:
379
+ st.json(sentence_similarity)
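The coloured tooltips rendered above get their background from similarity_color_picker and rgb_to_hex: a similarity score in [0, 1] is mapped to a hue that runs from red at low scores toward green at high scores. A small standalone sketch of that mapping, copied from the helpers above so it can be tried in isolation:

    import colorsys

    def similarity_color_picker(similarity: float):
        # hue grows with similarity: near red at low scores, yellow-green near 1.0
        value = int(similarity * 75)
        rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
        return [round(255 * x) for x in rgb]

    def rgb_to_hex(rgb):
        return '%02x%02x%02x' % tuple(rgb)

    for score in (0.1, 0.5, 0.9):
        print(score, "->", "#" + rgb_to_hex(similarity_color_picker(score)))
    # 0.1 -> reddish, 0.5 -> orange, 0.9 -> yellow-green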
pages/info.py ADDED
@@ -0,0 +1,92 @@
1
+ import streamlit as st
2
+
3
+
4
+ def app():
5
+ with open('style.css') as f:
6
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
7
+ footer = """
8
+ <div class="footer-custom">
9
+ Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
10
+ LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
11
+ Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
12
+ </div>
13
+ """
14
+ st.markdown(footer, unsafe_allow_html=True)
15
+
16
+ st.subheader("Intro")
17
+ intro = """
18
+ <div class="text">
19
+ Wikipedia Assistant is an example of a task usually referred to as Long-Form Question Answering (LFQA).
20
+ These systems function by querying large document stores for relevant information and subsequently using
21
+ the retrieved documents to generate accurate, multi-sentence answers. The documents related to a given
22
+ query, colloquially called context passages, are not used merely as source tokens for extracted answers,
23
+ but instead provide a larger context for the synthesis of original, abstractive long-form answers.
24
+ LFQA systems usually consist of three components:
25
+ <ul>
26
+ <li>A document store including content passages for a variety of topics</li>
27
+ <li>Encoder models to encode documents/questions such that it is possible to query the document store</li>
28
+ <li>A Seq2Seq language model capable of generating paragraph-long answers when given a question and
29
+ context passages retrieved from the document store</li>
30
+ </ul>
31
+ </div>
32
+ <br>
33
+ """
34
+ st.markdown(intro, unsafe_allow_html=True)
35
+ st.image("lfqa.png", caption="LFQA Architecture")
36
+ st.subheader("UI/UX")
37
+ st.write("Each sentence in the generated answer ends with a coloured tooltip; the colour ranges from red to green. "
38
+ "The tooltip contains a value representing answer sentence similarity to a specific sentence in the "
39
+ "Wikipedia context passages retrieved. Mouseover on the tooltip will show the sentence from the "
40
+ "Wikipedia context passage. If a sentence similarity is 1.0, the seq2seq model extracted and "
41
+ "copied the sentence verbatim from Wikipedia context passages. Lower values of sentence "
42
+ "similarity indicate the seq2seq model is struggling to generate a relevant sentence for the question "
43
+ "asked.")
44
+ st.image("wikipedia_answer.png", caption="Answer with similarity tooltips")
45
+ st.write("Below the generated answer are question-related Wikipedia context paragraphs (passages). One can view "
46
+ "these passages in a raw format retrieved using the 'Paragraphs' select menu option. The 'Sentences' menu "
47
+ "option shows the same paragraphs but on a sentence level. Finally, the 'Answer Similarity' menu option "
48
+ "shows the most similar three sentences from context paragraphs to each sentence in the generated answer.")
49
+ st.image("wikipedia_context.png", caption="Context paragraphs (passages)")
50
+
51
+ tts = """
52
+ <div class="text">
53
+ Wikipedia Assistant converts the text-based answer to speech via either Google text-to-speech engine or
54
+ <a href="https://github.com/espnet" target="_blank">ESPnet model</a> hosted on
55
+ <a href="https://huggingface.co/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan" target="_blank">
56
+ HuggingFace hub</a>
57
+ <br>
58
+ <br>
59
+ """
60
+ st.markdown(tts, unsafe_allow_html=True)
61
+
62
+ st.subheader("Tips")
63
+ tips = """
64
+ <div class="text">
65
+ The LFQA task is far from solved. Wikipedia Assistant will sometimes generate an answer unrelated to the question asked,
66
+ or even downright wrong. However, if the question is elaborate and more specific, there is a decent chance of
67
+ getting a legible answer. LFQA systems are targeting ELI5 non-factoid type of questions. A general guideline
68
+ is - questions starting with why, what, and how are better suited than where and who questions. Be elaborate.
69
+ <br><br>
70
+
71
+ For example, to ask a science-based question, Wikipedia Assistant is better suited to answer the question: "Why do
72
+ airplane jet engines leave contrails in the sky?" than "Why do contrails exist?". Detailed and precise questions
73
+ are more likely to match the right half a dozen relevant passages in a 20+ GB Wikipedia dump to construct a good
74
+ answer.
75
+ </div>
76
+ <br>
77
+ """
78
+ st.markdown(tips, unsafe_allow_html=True)
79
+ st.subheader("Technical details")
80
+ techinical_intro = """
81
+ <div class="text technical-details-info">
82
+ A question asked will be encoded with an <a href="https://huggingface.co/vblagoje/dpr-question_encoder-single-lfqa-wiki" target="_blank">encoder</a>
83
+ and sent to a server to find the most relevant Wikipedia passages. The Wikipedia <a href="https://huggingface.co/datasets/kilt_wikipedia" target="_blank">passages</a>
84
+ were previously encoded using a passage <a href="https://huggingface.co/vblagoje/dpr-ctx_encoder-single-lfqa-wiki" target="_blank">encoder</a> and
85
+ stored in the <a href="https://github.com/facebookresearch/faiss" target="_blank">Faiss</a> index. The question-matching passages (a.k.a. context passages) are retrieved from the Faiss
86
+ index and passed to a BART-based seq2seq <a href="https://huggingface.co/vblagoje/bart_lfqa" target="_blank">model</a> to
87
+ synthesize an original answer to the question.
88
+
89
+ </div>
90
+ """
91
+ st.markdown(technical_intro, unsafe_allow_html=True)
92
+
pages/settings.py ADDED
@@ -0,0 +1,95 @@
1
+ import streamlit as st
2
+
3
+ settings = {}
4
+
5
+ def app():
6
+ st.markdown("""
7
+ <style>
8
+ div[data-testid="stForm"] {
9
+ border: 0;
10
+ }
11
+ .footer-custom {
12
+ position: fixed;
13
+ bottom: 0;
14
+ width: 100%;
15
+ color: var(--text-color);
16
+ max-width: 698px;
17
+ font-size: 14px;
18
+ height: 50px;
19
+ padding: 10px 0;
20
+ z-index: 50;
21
+ }
22
+ footer {
23
+ display: none !important;
24
+ }
25
+ .footer-custom a {
26
+ color: var(--text-color);
27
+ }
28
+ button[kind="formSubmit"]{
29
+ margin-top: 40px;
30
+ border-radius: 20px;
31
+ padding: 5px 20px;
32
+ font-size: 18px;
33
+ background-color: var(--primary-color);
34
+ }
35
+ #lfqa-model-parameters {
36
+ margin-bottom: 50px;
37
+ font-size: 36px;
38
+ }
39
+ #tts-model-parameters {
40
+ font-size: 36px;
41
+ margin-top: 50px;
42
+ }
43
+ .stAlert {
44
+ width: 250px;
45
+ margin-top: 32px;
46
+ }
47
+ </style>
48
+ """, unsafe_allow_html=True)
49
+
50
+ with st.form("settings"):
51
+ footer = """
52
+ <div class="footer-custom">
53
+ Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
54
+ LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
55
+ Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
56
+ </div>
57
+ """
58
+ st.markdown(footer, unsafe_allow_html=True)
59
+
60
+ st.title("LFQA model parameters")
61
+
62
+ settings["min_length"] = st.slider("Min length", 20, 80, st.session_state["min_length"],
63
+ help="Min response length (words)")
64
+ st.markdown("""<hr></hr>""", unsafe_allow_html=True)
65
+ settings["max_length"] = st.slider("Max length", 128, 320, st.session_state["max_length"],
66
+ help="Max response length (words)")
67
+ st.markdown("""<hr></hr>""", unsafe_allow_html=True)
68
+ col1, col2 = st.columns(2)
69
+ with col1:
70
+ settings["do_sample"] = st.checkbox("Use sampling", st.session_state["do_sample"],
71
+ help="Whether or not to use sampling; use greedy decoding otherwise.")
72
+ with col2:
73
+ settings["early_stopping"] = st.checkbox("Early stopping", st.session_state["early_stopping"],
74
+ help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not.")
75
+ st.markdown("""<hr></hr>""", unsafe_allow_html=True)
76
+ settings["num_beams"] = st.slider("Num beams", 1, 16, st.session_state["num_beams"],
77
+ help="Number of beams for beam search. 1 means no beam search.")
78
+ st.markdown("""<hr></hr>""", unsafe_allow_html=True)
79
+ settings["temperature"] = st.slider("Temperature", 0.0, 1.0, st.session_state["temperature"], step=0.1,
80
+ help="The value used to module the next token probabilities")
81
+
82
+ st.title("TTS model parameters")
83
+ settings["tts"] = st.selectbox(label="Engine", options=("Google", "HuggingFace"),
84
+ index=["Google", "HuggingFace"].index(st.session_state["tts"]),
85
+ help="Answer text-to-speech engine")
86
+
87
+ # Every form must have a submit button.
88
+ col3, col4, col5, col6 = st.columns(4)
89
+ with col3:
90
+ submitted = st.form_submit_button("Save")
91
+ with col4:
92
+ if submitted:
93
+ for k, v in settings.items():
94
+ st.session_state[k] = v
95
+ st.success('App settings saved successfully.')
requirements-dev.txt ADDED
@@ -0,0 +1,4 @@
1
+ accelerate
2
+ datasets
3
+ transformers
4
+ sentence_transformers
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ nltk
2
+ click==8.0.3
3
+ streamlit==1.9.0
4
+ sentence_transformers
5
+ requests
6
+ pyjwt
7
+ typing
8
+ streamlit_option_menu
9
+ google-auth
10
+ google-cloud-texttospeech
11
+ protobuf<=3.20.1
style.css ADDED
@@ -0,0 +1,178 @@
1
+
2
+ .row-widget.stTextInput > div:first-of-type {
3
+ background: #fff;
4
+ display: flex;
5
+ border: 1px solid #dfe1e5;
6
+ box-shadow: none;
7
+ border-radius: 24px;
8
+ height: 50px;
9
+ width: auto;
10
+ margin: 10px auto 30px;
11
+ }
12
+
13
+ .row-widget.stTextInput > div:first-of-type:hover,
14
+ .row-widget.stTextInput > div:first-of-type:focus {
15
+ box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.2);
16
+ }
17
+
18
+ .row-widget.stTextInput .st-bq {
19
+ background-color: #fff;
20
+ }
21
+
22
+ .row-widget.stTextInput > label {
23
+ color: #b3b3b3;
24
+ }
25
+
26
+ .row-widget.stButton > button {
27
+ border-radius: 24px;
28
+ background-color: #B6C9B1;
29
+ color: #fff;
30
+ border: none;
31
+ padding: 6px 20px;
32
+ float: right;
33
+ background-image: none;
34
+ }
35
+
36
+ .row-widget.stButton > button:hover {
37
+ box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.2);
38
+ }
39
+
40
+ .row-widget.stButton > button:focus {
41
+ border: none;
42
+ color: #fff;
43
+ }
44
+
45
+ .footer-custom {
46
+ position: fixed;
47
+ bottom: 0;
48
+ width: 100%;
49
+ color: var(--text-color);
50
+ max-width: 698px;
51
+ font-size: 14px;
52
+ height: 50px;
53
+ padding: 10px 0;
54
+ z-index: 50;
55
+ }
56
+
57
+ .main {
58
+ padding: 20px;
59
+ }
60
+
61
+ footer {
62
+ display: none !important;
63
+ }
64
+
65
+ .footer-custom a {
66
+ color: var(--text-color);
67
+ }
68
+
69
+ #wikipedia-assistant {
70
+ font-size: 36px;
71
+ }
72
+
73
+ .generated-answer p {
74
+ font-size: 16px;
75
+ font-weight: bold;
76
+ }
77
+
78
+ .react-json-view {
79
+ margin: 40px 0 80px;
80
+ }
81
+
82
+ .tooltip {
83
+ text-align: center;
84
+ line-height: 20px;
85
+ display: table-caption;
86
+ font-size: 10px;
87
+ border-radius: 50%;
88
+ height: 20px;
89
+ width: 20px;
90
+ position: relative;
91
+ cursor: pointer;
92
+ color:#000;
93
+ }
94
+
95
+ .tooltip .tooltiptext {
96
+ visibility: hidden;
97
+ width: 280px;
98
+ text-align: center;
99
+ border-radius: 6px;
100
+ padding: 10px;
101
+ position: absolute;
102
+ z-index: 1;
103
+ top: 25px;
104
+ left: 50%;
105
+ margin-left: -140px;
106
+ font-size: 14px;
107
+ background-color: #fff;
108
+ border: 1px solid #ccc;
109
+ box-shadow: 0px 0px 3px 1px rgba(0, 0, 0, 0.16);
110
+ color: #000;
111
+ }
112
+
113
+ .tooltip:hover .tooltiptext {
114
+ visibility: visible;
115
+ }
116
+
117
+ .sentence-wrapper {
118
+ border-left: 4px solid #ffc423;
119
+ padding-left: 20px;
120
+ margin-bottom: 40px;
121
+ }
122
+
123
+ #context {
124
+ padding: 2rem 0 1rem;
125
+ }
126
+
127
+ hr {
128
+ margin: 2em 0 1em;
129
+ }
130
+
131
+ .technical-details-info {
132
+ margin-bottom: 100px;
133
+ }
134
+
135
+ .loader-wrapper {
136
+ display: flex;
137
+ align-items: center;
138
+ background-color: rgba(250, 202, 43, 0.2);
139
+ padding: 15px 20px;
140
+ border-radius: 6px;
141
+ }
142
+
143
+ .loader-wrapper p {
144
+ margin-bottom: 0;
145
+ margin-left: 20px;
146
+ }
147
+
148
+ .loader {
149
+ width: 30px;
150
+ height: 30px;
151
+ border: dotted 5px #868686;
152
+ border-radius: 100%;
153
+ animation: spin 1s linear infinite;
154
+ }
155
+
156
+ .loader-note {
157
+ font-size: 14px;
158
+ color: #b3b3b3;
159
+ margin-left: 5px;
160
+ }
161
+
162
+ @keyframes spin {
163
+ 0% {
164
+ transform: rotate(0deg) scale(0.8);
165
+ border-top-color: transparent;
166
+ border-right-color: transparent;
167
+ }
168
+ 50% { transform: rotate(180deg) scale(1.2);
169
+ border-color: #949494;
170
+ border-top-color: transparent;
171
+ border-right-color: transparent;
172
+ }
173
+ 100% { transform: rotate(360deg) scale(0.8);
174
+ border-color: #bbbbbb;
175
+ border-top-color: transparent;
176
+ border-right-color: transparent;
177
+ }
178
+ }
training/run_retriever_no_trainer.py ADDED
@@ -0,0 +1,381 @@
1
+ import argparse
2
+ import functools
3
+ import logging
4
+ import math
5
+ from random import choice, randint
6
+
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from accelerate.utils import set_seed
10
+ from datasets import load_dataset
11
+ from torch.utils import checkpoint
12
+ from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
13
+ from tqdm.auto import tqdm
14
+ from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSequenceClassification
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def get_parser():
20
+ parser = argparse.ArgumentParser(description="Train ELI5 retriever")
21
+ parser.add_argument(
22
+ "--dataset_name",
23
+ type=str,
24
+ default="vblagoje/lfqa",
25
+ help="The name of the dataset to use (via the datasets library).",
26
+ )
27
+
28
+ parser.add_argument(
29
+ "--per_device_train_batch_size",
30
+ type=int,
31
+ default=1024,
32
+ )
33
+
34
+ parser.add_argument(
35
+ "--per_device_eval_batch_size",
36
+ type=int,
37
+ default=1024,
38
+ help="Batch size (per device) for the evaluation dataloader.",
39
+ )
40
+
41
+ parser.add_argument(
42
+ "--max_length",
43
+ type=int,
44
+ default=128,
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--checkpoint_batch_size",
49
+ type=int,
50
+ default=32,
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--pretrained_model_name",
55
+ type=str,
56
+ default="google/bert_uncased_L-8_H-768_A-12",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--model_save_name",
61
+ type=str,
62
+ default="eli5_retriever_model_l-12_h-768_b-512-512",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--learning_rate",
67
+ type=float,
68
+ default=2e-4,
69
+ )
70
+
71
+ parser.add_argument(
72
+ "--weight_decay",
73
+ type=float,
74
+ default=0.2,
75
+ )
76
+
77
+ parser.add_argument(
78
+ "--log_freq",
79
+ type=int,
80
+ default=500,
81
+ help="Log train/validation loss every log_freq update steps"
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--num_train_epochs",
86
+ type=int,
87
+ default=4,
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--max_train_steps",
92
+ type=int,
93
+ default=None,
94
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
95
+ )
96
+
97
+ parser.add_argument(
98
+ "--gradient_accumulation_steps",
99
+ type=int,
100
+ default=1,
101
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--lr_scheduler_type",
106
+ type=SchedulerType,
107
+ default="linear", # this is linear with warmup
108
+ help="The scheduler type to use.",
109
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
110
+ )
111
+
112
+ parser.add_argument(
113
+ "--num_warmup_steps",
114
+ type=int,
115
+ default=100,
116
+ help="Number of steps for the warmup in the lr scheduler."
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--warmup_percentage",
121
+ type=float,
122
+ default=0.08,
123
+ help="Number of steps for the warmup in the lr scheduler."
124
+ )
125
+ return parser
126
+
127
+
128
+ class RetrievalQAEmbedder(torch.nn.Module):
129
+ def __init__(self, sent_encoder):
130
+ super(RetrievalQAEmbedder, self).__init__()
131
+ dim = sent_encoder.config.hidden_size
132
+ self.bert_query = sent_encoder
133
+ self.output_dim = 128
134
+ self.project_query = torch.nn.Linear(dim, self.output_dim, bias=False)
135
+ self.project_doc = torch.nn.Linear(dim, self.output_dim, bias=False)
136
+ self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
137
+
138
+ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
139
+ # reproduces BERT forward pass with checkpointing
140
+ if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
141
+ return self.bert_query(input_ids, attention_mask=attention_mask)[1]
142
+ else:
143
+ # prepare implicit variables
144
+ device = input_ids.device
145
+ input_shape = input_ids.size()
146
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
147
+ head_mask = [None] * self.bert_query.config.num_hidden_layers
148
+ extended_attention_mask: torch.Tensor = self.bert_query.get_extended_attention_mask(
149
+ attention_mask, input_shape, device
150
+ )
151
+
152
+ # define function for checkpointing
153
+ def partial_encode(*inputs):
154
+ encoder_outputs = self.bert_query.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask, )
155
+ sequence_output = encoder_outputs[0]
156
+ pooled_output = self.bert_query.pooler(sequence_output)
157
+ return pooled_output
158
+
159
+ # run embedding layer on everything at once
160
+ embedding_output = self.bert_query.embeddings(
161
+ input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
162
+ )
163
+ # run encoding and pooling on one mini-batch at a time
164
+ pooled_output_list = []
165
+ for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
166
+ b_embedding_output = embedding_output[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
167
+ b_attention_mask = extended_attention_mask[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
168
+ pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
169
+ pooled_output_list.append(pooled_output)
170
+ return torch.cat(pooled_output_list, dim=0)
171
+
172
+ def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
173
+ q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
174
+ return self.project_query(q_reps)
175
+
176
+ def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
177
+ a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
178
+ return self.project_doc(a_reps)
179
+
180
+ def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
181
+ device = q_ids.device
182
+ q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
183
+ a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
184
+ compare_scores = torch.mm(q_reps, a_reps.t())
185
+ loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
186
+ loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
187
+ loss = (loss_qa + loss_aq) / 2
188
+ return loss
189
+
190
+
191
+ class ELI5DatasetQARetriever(Dataset):
192
+ def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
193
+ self.data = examples_array
194
+ self.answer_thres = extra_answer_threshold
195
+ self.min_length = min_answer_length
196
+ self.training = training
197
+ self.n_samples = self.data.num_rows if n_samples is None else n_samples
198
+
199
+ def __len__(self):
200
+ return self.n_samples
201
+
202
+ def make_example(self, idx):
203
+ example = self.data[idx]
204
+ question = example["title"]
205
+ if self.training:
206
+ answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"])) if i == 0 or sc >= self.answer_thres]
207
+ answer_tab = choice(answers).split(" ")
208
+ start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
209
+ answer_span = " ".join(answer_tab[start_idx:])
210
+ else:
211
+ answer_span = example["answers"]["text"][0]
212
+ return question, answer_span
213
+
214
+ def __getitem__(self, idx):
215
+ return self.make_example(idx % self.data.num_rows)
216
+
217
+
218
+ def make_qa_retriever_batch(qa_list, tokenizer, max_len=64):
219
+ q_ls = [q for q, a in qa_list]
220
+ a_ls = [a for q, a in qa_list]
221
+ q_toks = tokenizer(q_ls, padding="max_length", max_length=max_len, truncation=True)
222
+ q_ids, q_mask = (
223
+ torch.LongTensor(q_toks["input_ids"]),
224
+ torch.LongTensor(q_toks["attention_mask"])
225
+ )
226
+ a_toks = tokenizer(a_ls, padding="max_length", max_length=max_len, truncation=True)
227
+ a_ids, a_mask = (
228
+ torch.LongTensor(a_toks["input_ids"]),
229
+ torch.LongTensor(a_toks["attention_mask"]),
230
+ )
231
+ return q_ids, q_mask, a_ids, a_mask
232
+
233
+
234
+ def evaluate_qa_retriever(model, data_loader):
235
+ # make iterator
236
+ epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
237
+ tot_loss = 0.0
238
+ with torch.no_grad():
239
+ for step, batch in enumerate(epoch_iterator):
240
+ q_ids, q_mask, a_ids, a_mask = batch
241
+ loss = model(q_ids, q_mask, a_ids, a_mask)
242
+ tot_loss += loss.item()
243
+ return tot_loss / (step + 1)
244
+
245
+
246
+ def train(config):
247
+ set_seed(42)
248
+ args = config["args"]
249
+ data_files = {"train": "train.json", "validation": "validation.json", "test": "test.json"}
250
+ eli5 = load_dataset(args.dataset_name, data_files=data_files)
251
+
252
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
253
+ accelerator = Accelerator()
254
+ # Make one log on every process with the configuration for debugging.
255
+ logging.basicConfig(
256
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
257
+ datefmt="%m/%d/%Y %H:%M:%S",
258
+ level=logging.INFO,
259
+ )
260
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
261
+ logger.info(accelerator.state)
262
+
263
+ # prepare torch Dataset objects
264
+ train_dataset = ELI5DatasetQARetriever(eli5['train'], training=True)
265
+ valid_dataset = ELI5DatasetQARetriever(eli5['validation'], training=False)
266
+
267
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
268
+ base_model = AutoModel.from_pretrained(args.pretrained_model_name)
269
+
270
+ model = RetrievalQAEmbedder(base_model)
271
+ no_decay = ['bias', 'LayerNorm.weight']
272
+ optimizer_grouped_parameters = [
273
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
274
+ 'weight_decay': args.weight_decay},
275
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
276
+ ]
277
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
278
+
279
+ model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
280
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
281
+ sampler=RandomSampler(train_dataset), collate_fn=model_collate_fn)
282
+
283
+ model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
284
+ eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
285
+ sampler=SequentialSampler(valid_dataset), collate_fn=model_collate_fn)
286
+
287
+ # train the model
288
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
289
+ train_dataloader, eval_dataloader)
290
+ # Scheduler and math around the number of training steps.
291
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
292
+ if args.max_train_steps is None:
293
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
294
+ else:
295
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
296
+
297
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
298
+ args.warmup_percentage)
299
+ scheduler = get_scheduler(
300
+ name=args.lr_scheduler_type,
301
+ optimizer=optimizer,
302
+ num_warmup_steps=num_warmup_steps,
303
+ num_training_steps=args.max_train_steps,
304
+ )
305
+
306
+ # Train!
307
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
308
+
309
+ logger.info("***** Running training *****")
310
+ logger.info(f" Num examples = {len(train_dataset)}")
311
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
312
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
313
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
314
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
315
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
316
+ logger.info(f" Warmup steps = {num_warmup_steps}")
317
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
318
+
319
+ loc_loss = 0.0
320
+ current_loss = 0.0
321
+ checkpoint_step = 0
322
+
323
+ completed_steps = checkpoint_step
324
+ progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
325
+ disable=not accelerator.is_local_main_process)
326
+ for epoch in range(args.num_train_epochs):
327
+ model.train()
328
+ for step, batch in enumerate(train_dataloader, start=checkpoint_step):
331
+ # model inputs
332
+ q_ids, q_mask, a_ids, a_mask = batch
333
+ pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
334
+ loss = pre_loss.sum() / args.gradient_accumulation_steps
335
+ accelerator.backward(loss)
336
+ loc_loss += loss.item()
337
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
338
+ current_loss = loc_loss
339
+ optimizer.step()
340
+ scheduler.step()
341
+ optimizer.zero_grad()
342
+ progress_bar.update(1)
343
+ progress_bar.set_postfix(loss=loc_loss)
344
+ loc_loss = 0
345
+ completed_steps += 1
346
+
347
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
348
+ accelerator.wait_for_everyone()
349
+ unwrapped_model = accelerator.unwrap_model(model)
350
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
351
+ logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
352
+ if args.wandb and accelerator.is_local_main_process:
353
+ import wandb
354
+ wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})
355
+
356
+ if completed_steps >= args.max_train_steps:
357
+ break
358
+
359
+ logger.info("Saving model {}".format(args.model_save_name))
360
+ accelerator.wait_for_everyone()
361
+ unwrapped_model = accelerator.unwrap_model(model)
362
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
363
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
364
+ logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))
365
+
366
+
367
+ if __name__ == "__main__":
368
+ parser = get_parser()
369
+ parser.add_argument(
370
+ "--wandb",
371
+ action="store_true",
372
+ help="Whether to use W&B logging",
373
+ )
374
+ main_args, _ = parser.parse_known_args()
375
+ config = {"args": main_args}
376
+ if main_args.wandb:
377
+ import wandb
378
+ wandb.init(project="Retriever")
379
+
380
+ train(config=config)
381
+
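A minimal sketch (not part of the commit; tensor sizes and values are illustrative) of the in-batch negative objective that RetrievalQAEmbedder.forward above implements: each question's positive answer sits at the same batch position, every other answer in the batch acts as a negative, and cross-entropy is applied over both the question-to-answer and answer-to-question score matrices.

    import torch

    # toy projected embeddings: a batch of 4 questions and 4 answers, 128-d each
    q_reps = torch.randn(4, 128)
    a_reps = torch.randn(4, 128)

    scores = torch.mm(q_reps, a_reps.t())        # (4, 4) similarity matrix
    targets = torch.arange(scores.shape[0])      # the diagonal holds the positive pairs
    ce = torch.nn.CrossEntropyLoss(reduction="mean")
    loss = (ce(scores, targets) + ce(scores.t(), targets)) / 2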
training/run_retriever_no_trainer_gpl.py ADDED
@@ -0,0 +1,403 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import List, Any, Union, Optional
6
+
7
+ import torch
8
+ import ujson
9
+ from accelerate import Accelerator
10
+ from accelerate.utils import set_seed
11
+ from torch import nn, Tensor
12
+ from torch.nn import functional as F
13
+ from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
14
+ from tqdm.auto import tqdm
15
+ from transformers import get_scheduler, AutoTokenizer, AutoModel, AdamW, SchedulerType, PreTrainedTokenizerBase, AutoModelForSequenceClassification, BatchEncoding
16
+ from transformers.file_utils import PaddingStrategy
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def get_parser():
22
+ parser = argparse.ArgumentParser(description="Train LFQA retriever")
23
+ parser.add_argument(
24
+ "--dpr_input_file",
25
+ type=str,
26
+ help="DPR formatted input file with question/positive/negative pairs in a JSONL file",
27
+ )
28
+
29
+ parser.add_argument(
30
+ "--per_device_train_batch_size",
31
+ type=int,
32
+ default=32,
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--per_device_eval_batch_size",
37
+ type=int,
38
+ default=32,
39
+ help="Batch size (per device) for the evaluation dataloader.",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--max_length",
44
+ type=int,
45
+ default=128,
46
+ )
47
+
48
+
49
+ parser.add_argument(
50
+ "--pretrained_model_name",
51
+ type=str,
52
+ default="sentence-transformers/all-MiniLM-L6-v2",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--ce_model_name",
57
+ type=str,
58
+ default="cross-encoder/ms-marco-MiniLM-L-6-v2",
59
+ )
60
+
61
+ parser.add_argument(
62
+ "--model_save_name",
63
+ type=str,
64
+ default="eli5_retriever_model_l-12_h-768_b-512-512",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--learning_rate",
69
+ type=float,
70
+ default=2e-5,
71
+ )
72
+
73
+ parser.add_argument(
74
+ "--weight_decay",
75
+ type=float,
76
+ default=0.01,
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--log_freq",
81
+ type=int,
82
+ default=500,
83
+ help="Log train/validation loss every log_freq update steps"
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--num_train_epochs",
88
+ type=int,
89
+ default=4,
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--max_train_steps",
94
+ type=int,
95
+ default=None,
96
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--gradient_accumulation_steps",
101
+ type=int,
102
+ default=1,
103
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--lr_scheduler_type",
108
+ type=SchedulerType,
109
+ default="linear", # this is linear with warmup
110
+ help="The scheduler type to use.",
111
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--num_warmup_steps",
116
+ type=int,
117
+ default=100,
118
+ help="Number of steps for the warmup in the lr scheduler."
119
+ )
120
+
121
+ parser.add_argument(
122
+ "--warmup_percentage",
123
+ type=float,
124
+ default=0.08,
125
+ help="Number of steps for the warmup in the lr scheduler."
126
+ )
127
+ return parser
128
+
129
+
130
+ @dataclass
131
+ class InputExample:
132
+ guid: str = ""
133
+ texts: List[str] = None
134
+ label: Union[int, float] = 0
135
+
136
+
137
+ class DPRDataset(Dataset):
138
+ """
139
+ Dataset DPR format of question, answers, positive, negative, and hard negative passages
140
+ See https://github.com/facebookresearch/DPR#retriever-input-data-format for more details
141
+ """
142
+
143
+ def __init__(self, file_path: str, include_all_positive: bool = False) -> None:
144
+ super().__init__()
145
+ with open(file_path, "r") as fp:
146
+ self.data = []
147
+
148
+ def dpr_example_to_input_example(idx, dpr_item):
149
+ examples = []
150
+ for p_idx, p_item in enumerate(dpr_item["positive_ctxs"]):
151
+ for n_idx, n_item in enumerate(dpr_item["negative_ctxs"]):
152
+ examples.append(InputExample(guid=[idx, p_idx, n_idx], texts=[dpr_item["question"],
153
+ p_item["text"],
154
+ n_item["text"]]))
155
+ if not include_all_positive:
156
+ break
157
+ return examples
158
+
159
+ for idx, line in enumerate(fp):
160
+ self.data.extend(dpr_example_to_input_example(idx, ujson.loads(line)))
161
+
162
+ def __len__(self):
163
+ return len(self.data)
164
+
165
+ def __getitem__(self, index):
166
+ return self.data[index]
167
+
168
+
169
+ def dpr_collate_fn(batch):
170
+ query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
171
+ query, pos, neg = zip(*[example.texts for example in batch])
172
+ return (query_id, pos_id, neg_id), (query, pos, neg)
173
+
174
+
175
+ # Mean Pooling - Take attention mask into account for correct averaging
176
+ def mean_pooling(model_output, attention_mask):
177
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
178
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
179
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
180
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
181
+ return sum_embeddings / sum_mask
182
+
183
+
184
+ @dataclass
185
+ class CrossEncoderCollator:
186
+ tokenizer: PreTrainedTokenizerBase
187
+ model: Any
188
+ target_tokenizer: PreTrainedTokenizerBase
189
+ padding: Union[bool, str, PaddingStrategy] = True
190
+ max_length: Optional[int] = None
191
+ pad_to_multiple_of: Optional[int] = None
192
+ return_tensors: str = "pt"
193
+
194
+ def __call__(self, batch):
195
+ query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
196
+ query, pos_passage, neg_passage = zip(*[example.texts for example in batch])
197
+ batch_input: List[List[str]] = list(zip(query, pos_passage)) + list(zip(query, neg_passage))
198
+ features = self.tokenizer(batch_input, padding=self.padding, truncation=True,
199
+ return_tensors=self.return_tensors)
200
+ with torch.no_grad():
201
+ scores = self.model(**features).logits
202
+
203
+ labels = scores[:len(query)] - scores[len(query):]
204
+ batch_input: List[str] = list(query) + list(pos_passage) + list(neg_passage)
205
+ #breakpoint()
206
+ encoded_input = self.target_tokenizer(batch_input, padding=True, truncation=True,
207
+ max_length=256, return_tensors='pt')
208
+
209
+ encoded_input["labels"] = labels
210
+
211
+ return encoded_input
212
+
213
+
214
+ class RetrievalQAEmbedder(torch.nn.Module):
215
+ def __init__(self, sent_encoder, sent_tokenizer, batch_size:int = 32):
216
+ super(RetrievalQAEmbedder, self).__init__()
217
+ dim = sent_encoder.config.hidden_size
218
+ self.model = sent_encoder
219
+ self.tokenizer = sent_tokenizer
220
+ self.scale = 1
221
+ self.similarity_fct = 'dot'
222
+ self.batch_size = batch_size
223
+ self.loss_fct = nn.MSELoss()
224
+
225
+ def forward(self, examples: BatchEncoding):
226
+ # Tokenize sentences
227
+ labels = examples.pop("labels")
228
+ # Compute token embeddings
229
+ model_output = self.model(**examples)
230
+
231
+ examples["labels"] = labels
232
+
233
+ # Perform pooling. In this case, mean pooling
234
+ sentence_embeddings = mean_pooling(model_output, examples['attention_mask'])
235
+ target_shape = (3, self.batch_size, sentence_embeddings.shape[-1])
236
+ sentence_embeddings_reshaped = torch.reshape(sentence_embeddings, target_shape)
237
+
238
+ #breakpoint()
239
+
240
+ embeddings_query = sentence_embeddings_reshaped[0]
241
+ embeddings_pos = sentence_embeddings_reshaped[1]
242
+ embeddings_neg = sentence_embeddings_reshaped[2]
243
+
244
+ if self.similarity_fct == 'cosine':
245
+ embeddings_query = F.normalize(embeddings_query, p=2, dim=1)
246
+ embeddings_pos = F.normalize(embeddings_pos, p=2, dim=1)
247
+ embeddings_neg = F.normalize(embeddings_neg, p=2, dim=1)
248
+
249
+ scores_pos = (embeddings_query * embeddings_pos).sum(dim=-1) * self.scale
250
+ scores_neg = (embeddings_query * embeddings_neg).sum(dim=-1) * self.scale
251
+ margin_pred = scores_pos - scores_neg
252
+ #breakpoint()
253
+ return self.loss_fct(margin_pred, labels.squeeze())
254
+
255
+
256
+ def evaluate_qa_retriever(model, data_loader):
257
+ # make iterator
258
+ epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
259
+ tot_loss = 0.0
260
+ with torch.no_grad():
261
+ for step, batch in enumerate(epoch_iterator):
262
+ q_ids, q_mask, a_ids, a_mask = batch
263
+ loss = model(q_ids, q_mask, a_ids, a_mask)
264
+ tot_loss += loss.item()
265
+ return tot_loss / (step + 1)
266
+
267
+
268
+ def train(config):
269
+ set_seed(42)
270
+ args = config["args"]
271
+
272
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
273
+ accelerator = Accelerator()
274
+ # Make one log on every process with the configuration for debugging.
275
+ logging.basicConfig(
276
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
277
+ datefmt="%m/%d/%Y %H:%M:%S",
278
+ level=logging.INFO,
279
+ )
280
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
281
+ logger.info(accelerator.state)
282
+
283
+ # prepare torch Dataset objects
284
+ train_dataset = DPRDataset(file_path=args.dpr_input_file)
285
+ valid_dataset = Dataset()  # placeholder: replace with a real validation Dataset before enabling evaluation
286
+
287
+ base_tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
288
+ base_model = AutoModel.from_pretrained(args.pretrained_model_name)
289
+
290
+ ce_tokenizer = AutoTokenizer.from_pretrained(args.ce_model_name)
291
+ ce_model = AutoModelForSequenceClassification.from_pretrained(args.ce_model_name)
292
+ _ = ce_model.eval()
293
+
294
+ model = RetrievalQAEmbedder(base_model, base_tokenizer, batch_size=args.per_device_train_batch_size)
295
+ no_decay = ['bias', 'LayerNorm.weight']
296
+ optimizer_grouped_parameters = [
297
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
298
+ 'weight_decay': args.weight_decay},
299
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
300
+ ]
301
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
302
+
303
+ cec = CrossEncoderCollator(model=ce_model, tokenizer=ce_tokenizer, target_tokenizer=base_tokenizer)
304
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
305
+ sampler=RandomSampler(train_dataset), collate_fn=cec)
306
+
307
+ eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
308
+ sampler=SequentialSampler(valid_dataset), collate_fn=cec)
309
+
310
+ # train the model
311
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
312
+ train_dataloader, eval_dataloader)
313
+ # Scheduler and math around the number of training steps.
314
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
315
+ if args.max_train_steps is None:
316
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
317
+ else:
318
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
319
+
320
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
321
+ args.warmup_percentage)
322
+ scheduler = get_scheduler(
323
+ name=args.lr_scheduler_type,
324
+ optimizer=optimizer,
325
+ num_warmup_steps=num_warmup_steps,
326
+ num_training_steps=args.max_train_steps,
327
+ )
328
+
329
+ # Train!
330
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
331
+
332
+ logger.info("***** Running training *****")
333
+ logger.info(f" Num examples = {len(train_dataset)}")
334
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
335
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
336
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
337
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
338
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
339
+ logger.info(f" Warmup steps = {num_warmup_steps}")
340
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
341
+
342
+ loc_loss = 0.0
343
+ current_loss = 0.0
344
+ checkpoint_step = 0
345
+
346
+ completed_steps = checkpoint_step
347
+ progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
348
+ disable=not accelerator.is_local_main_process)
349
+ for epoch in range(args.num_train_epochs):
350
+ model.train()
351
+ for step, batch in enumerate(train_dataloader, start=checkpoint_step):
352
+ # model inputs
353
+ pre_loss = model(batch)
354
+ loss = pre_loss / args.gradient_accumulation_steps
355
+ accelerator.backward(loss)
356
+ loc_loss += loss.item()
357
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
358
+ current_loss = loc_loss
359
+ optimizer.step()
360
+ scheduler.step()
361
+ optimizer.zero_grad()
362
+ progress_bar.update(1)
363
+ progress_bar.set_postfix(loss=loc_loss)
364
+ loc_loss = 0
365
+ completed_steps += 1
366
+
367
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
368
+ # accelerator.wait_for_everyone()
369
+ # unwrapped_model = accelerator.unwrap_model(model)
370
+ # eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
371
+ eval_loss = 0
372
+ logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
373
+ if args.wandb and accelerator.is_local_main_process:
374
+ import wandb
375
+ wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})
376
+
377
+ if completed_steps >= args.max_train_steps:
378
+ break
379
+
380
+ logger.info("Saving model {}".format(args.model_save_name))
381
+ accelerator.wait_for_everyone()
382
+ unwrapped_model = accelerator.unwrap_model(model)
383
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
384
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
385
+ logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))
386
+
387
+
388
+ if __name__ == "__main__":
389
+ parser = get_parser()
390
+ parser.add_argument(
391
+ "--wandb",
392
+ action="store_true",
393
+ help="Whether to use W&B logging",
394
+ )
395
+ main_args, _ = parser.parse_known_args()
396
+ config = {"args": main_args}
397
+ if main_args.wandb:
398
+ import wandb
399
+
400
+ wandb.init(project="Retriever")
401
+
402
+ train(config=config)
403
+
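A rough sketch (not from the commit; the embeddings and teacher margins below are toy values) of the margin-MSE objective the GPL-style script above optimizes: the cross-encoder collator supplies a score margin between the positive and negative passage for each query, and the bi-encoder is trained so that its dot-product margin matches it.

    import torch
    from torch import nn

    q = torch.randn(2, 384)    # mean-pooled query embeddings (batch of 2)
    pos = torch.randn(2, 384)  # positive passage embeddings
    neg = torch.randn(2, 384)  # negative passage embeddings

    # teacher margins from the cross-encoder: score(q, pos) - score(q, neg)
    teacher_margin = torch.tensor([2.3, -0.7])

    student_margin = (q * pos).sum(dim=-1) - (q * neg).sum(dim=-1)
    loss = nn.MSELoss()(student_margin, teacher_margin)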
training/run_seq2seq_no_trainer.py ADDED
@@ -0,0 +1,446 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import re
5
+
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from accelerate.utils import set_seed
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.auto import tqdm
12
+ from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSeq2SeqLM, \
13
+ DataCollatorWithPadding
14
+
15
+ from datasets import load_dataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def get_parser():
21
+ parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
22
+ parser.add_argument(
23
+ "--dataset_name",
24
+ type=str,
25
+ default="vblagoje/lfqa",
26
+ help="The name of the dataset to use (via the datasets library).",
27
+ )
28
+
29
+ parser.add_argument(
30
+ "--per_device_train_batch_size",
31
+ type=int,
32
+ default=4,
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--per_device_eval_batch_size",
37
+ type=int,
38
+ default=4,
39
+ help="Batch size (per device) for the evaluation dataloader.",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--pretrained_model_name",
44
+ type=str,
45
+ default="facebook/bart-large",
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--model_save_name",
50
+ type=str,
51
+ default="eli5_bart_model",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--learning_rate",
56
+ type=float,
57
+ default=2e-4,
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--weight_decay",
62
+ type=float,
63
+ default=0.0,
64
+ help="Weight decay to use."
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--log_freq",
69
+ type=int,
70
+ default=100,
71
+ help="Log train/validation loss every log_freq update steps"
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--ignore_pad_token_for_loss",
76
+ type=bool,
77
+ default=True,
78
+ help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
79
+ )
80
+
81
+ parser.add_argument(
82
+ "--num_train_epochs",
83
+ type=int,
84
+ default=3,
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--max_train_steps",
89
+ type=int,
90
+ default=None,
91
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--gradient_accumulation_steps",
96
+ type=int,
97
+ default=16,
98
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--pad_to_max_length",
103
+ action="store_true",
104
+ help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
105
+ )
106
+
107
+ parser.add_argument(
108
+ "--overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets"
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--max_source_length",
113
+ type=int,
114
+ default=1024,
115
+ help="The maximum total input sequence length after "
116
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--max_target_length",
121
+ type=int,
122
+ default=360,
123
+ help="The maximum total sequence length for target text after "
124
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--lr_scheduler_type",
129
+ type=SchedulerType,
130
+ default="linear", # this is linear with warmup
131
+ help="The scheduler type to use.",
132
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--num_warmup_steps",
137
+ type=int,
138
+ default=None,
139
+ help="Number of steps for the warmup in the lr scheduler."
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--warmup_percentage",
144
+ type=float,
145
+ default=0.08,
146
+ help="Number of steps for the warmup in the lr scheduler."
147
+ )
148
+ return parser
149
+
150
+
151
+ def cleanup_references(text):
152
+ # URL reference where we need to remove both the link text and URL
153
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal
154
+ # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
155
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
156
+ result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
157
+
158
+ # URL reference where we need to preserve link text but remove URL
159
+ # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
160
+ # At the outbreak of the Civil War, Leyburn left his church and joined the South.
161
+ result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
162
+
163
+ # lastly remove just dangling _URL_[0-9]_ URL references
164
+ result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
165
+ return result
166
+
167
+
168
+ def clean_answer(text):
169
+ result = cleanup_references(text)
170
+ result = result.replace("\n", " ")
171
+ result = re.sub(r"\s\s+", " ", result)
172
+ result = re.sub(r"BULLET::::-", "", result)
173
+ return result.strip()
174
+
175
+
176
+ def clean_question(text):
177
+ result = cleanup_references(text)
178
+ result = result.replace("\n", " ")
179
+ result = re.sub(r"\s\s+", " ", result)
180
+ result = result.replace("[deleted]", "")
181
+ return result.lower().strip()
182
+
183
+
184
+ def prepare_support_docs(example):
185
+ provenances = example["output"][-1]["provenance"]
186
+ context = "<P> " + " <P> ".join([p["text"] for p in provenances])
187
+ return {"context": context}
188
+
189
+
190
+ def preprocess_eli5(examples, **fn_kwargs):
191
+ document_cache = fn_kwargs["document_cache"]
192
+ training = fn_kwargs.get("training", True)
193
+ extra_answer_threshold = fn_kwargs.get("extra_answer_threshold", 3)
194
+ include_selftext = fn_kwargs.get("include_selftext", False)
195
+ exclude_answer_patterns = fn_kwargs.get("exclude_answer_patterns", [])
196
+
197
+ questions, contexts, answers = [], [], []
198
+ for q_id, question, selftext, answer in zip(examples["q_id"], examples["title"], examples["selftext"],
199
+ examples["answers"]):
200
+ accepted_answer_idx = []
201
+ if training:
202
+ accepted_answer_idx = [idx for idx, score in enumerate(answer["score"]) if
203
+ score > extra_answer_threshold]
204
+ if not training or not accepted_answer_idx:
205
+ accepted_answer_idx = [0]
206
+ document = document_cache[q_id]
207
+ for idx in accepted_answer_idx:
208
+ skip_answer = any([p.search(answer["text"][idx]) for p in exclude_answer_patterns])
209
+ if skip_answer:
210
+ continue
211
+ if include_selftext:
212
+ questions.append(clean_question(f"{question} {selftext}"))
213
+ else:
214
+ questions.append(clean_question(question))
215
+ contexts.append(document.lower().strip())
216
+ answers.append(clean_answer(answer["text"][idx]))
217
+
218
+ return {"question": questions, "context": contexts, "answer": answers}
219
+
220
+
221
+ def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
222
+ model.eval()
223
+ num_eval_steps = math.ceil(len(dataloader))
224
+ progress_bar = tqdm(range(num_eval_steps), disable=not accelerator.is_local_main_process)
225
+ total_loss = 0.
226
+ with torch.no_grad():
227
+ for step, batch in enumerate(dataloader):
228
+ outputs = model(**batch)
229
+ loss = outputs.loss
230
+ total_loss += loss.item()
231
+ progress_bar.update(1)
232
+ progress_bar.set_postfix(loss=round((total_loss / (step + 1)), 3))
233
+ return total_loss / (step + 1)
234
+
235
+
236
+ def train(config):
237
+ set_seed(42)
238
+ args = config["args"]
239
+ eli5 = load_dataset(args.dataset_name)
240
+
241
+ support_docs = load_dataset("vblagoje/lfqa_support_docs")
242
+
243
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
244
+ accelerator = Accelerator()
245
+ # Make one log on every process with the configuration for debugging.
246
+ logging.basicConfig(
247
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
248
+ datefmt="%m/%d/%Y %H:%M:%S",
249
+ level=logging.INFO,
250
+ )
251
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
252
+ logger.info(accelerator.state)
253
+
254
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
255
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)
256
+
257
+ # Optimizer
258
+ # Split weights in two groups, one with weight decay and the other not.
259
+ no_decay = ["bias", "LayerNorm.weight"]
260
+ optimizer_grouped_parameters = [
261
+ {
262
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
263
+ "weight_decay": args.weight_decay,
264
+ },
265
+ {
266
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
267
+ "weight_decay": 0.0,
268
+ },
269
+ ]
270
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
271
+
272
+ processed_datasets = {}
273
+ support_docs_prepared = {}
274
+ with accelerator.main_process_first():
275
+ for split in ["train", "validation"]:
276
+ support_docs_prepared[split] = support_docs[split].map(prepare_support_docs,
277
+ batched=False,
278
+ cache_file_name=f"./support_docs_{split}.arrow",
279
+ load_from_cache_file=not args.overwrite_cache,
280
+ desc="Preparing support docs",
281
+ )
282
+ column_names = eli5["train"].column_names
283
+ for split in ["train", "validation"]:
284
+ d_cache = dict([(e["id"], e["context"]) for e in tqdm(support_docs_prepared[split],
285
+ desc=f"Adding support docs to LFQA {split}")])
286
+ processed_datasets[split] = eli5[split].map(preprocess_eli5,
287
+ batched=True,
288
+ remove_columns=column_names,
289
+ cache_file_name=f"./processed_datasets_{split}.arrow",
290
+ load_from_cache_file=not args.overwrite_cache,
291
+ desc="Preparing dataset for tokenization",
292
+ fn_kwargs={"document_cache": d_cache,
293
+ "training": split == "train",
294
+ "exclude_answer_patterns": [re.compile("not sure what you"),
295
+ re.compile("\n\n >")]}
296
+ )
297
+
298
+ padding = "max_length" if args.pad_to_max_length else False
299
+ # Temporarily set max_target_length for training.
300
+ max_target_length = args.max_target_length
301
+
302
+ label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
303
+
304
+ def tokenize_dataset(examples):
305
+ inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
306
+ targets = examples["answer"]
307
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
308
+
309
+ # Setup the tokenizer for targets
310
+ with tokenizer.as_target_tokenizer():
311
+ labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
312
+ return_tensors="np")
313
+
314
+ model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()
315
+ # replace pad_token_id with label_pad_token_id to avoid loss calculation on those tokens
316
+ labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
317
+ label_pad_token_id, labels["input_ids"])
318
+
319
+ model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
320
+ return model_inputs
321
+
322
+ tokenized_datasets = {}
323
+ with accelerator.main_process_first():
324
+ for split, dataset in processed_datasets.items():
325
+ tokenized_datasets[split] = dataset.map(
326
+ tokenize_dataset,
327
+ batched=True,
328
+ cache_file_name=f"./tokenized_dataset_{split}.arrow",
329
+ remove_columns=dataset.column_names,
330
+ load_from_cache_file=not args.overwrite_cache,
331
+ desc="Running tokenizer on dataset"
332
+ )
333
+
334
+ train_dataset = tokenized_datasets["train"]
335
+ eval_dataset = tokenized_datasets["validation"]
336
+ train_dataset.set_format(type='torch')
337
+ eval_dataset.set_format(type='torch')
338
+
339
+ data_collator = DataCollatorWithPadding(tokenizer, "max_length")
340
+
341
+ # first epoch we don't shuffle
342
+ train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
343
+ collate_fn=data_collator)
344
+ eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)
345
+
346
+ # train the model
347
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
348
+ eval_dataloader)
349
+ # Scheduler and math around the number of training steps.
350
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
351
+ if args.max_train_steps is None:
352
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
353
+ else:
354
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
355
+
356
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
357
+ args.warmup_percentage)
358
+ scheduler = get_scheduler(
359
+ name=args.lr_scheduler_type,
360
+ optimizer=optimizer,
361
+ num_warmup_steps=num_warmup_steps,
362
+ num_training_steps=args.max_train_steps,
363
+ )
364
+ # Train!
365
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
366
+
367
+ logger.info("***** Running training *****")
368
+ logger.info(f" Num examples = {len(train_dataset)}")
369
+ logger.info(f" Num eval examples = {len(eval_dataset)}")
370
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
371
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
372
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
373
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
374
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
375
+ logger.info(f" Warmup steps = {num_warmup_steps}")
376
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
377
+
378
+ # Only show the progress bar once on each machine.
379
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
380
+ completed_steps = 0
381
+ switched_train_dataloader = False
382
+ for epoch in range(args.num_train_epochs):
383
+ model.train()
384
+ if epoch > 0 and not switched_train_dataloader:
385
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
386
+ shuffle=True, collate_fn=data_collator)
387
+ train_dataloader = accelerator.prepare(train_dataloader)
388
+ switched_train_dataloader = True
389
+
390
+ for step, batch in enumerate(train_dataloader):
391
+ outputs = model(**batch)
392
+ loss = torch.mean(outputs.loss)
393
+ accelerator.backward(loss)
394
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
395
+ optimizer.step()
396
+ scheduler.step()
397
+ optimizer.zero_grad()
398
+ progress_bar.update(1)
399
+ progress_bar.set_postfix(loss=round(loss.item(), 3))
400
+ completed_steps += 1
401
+
402
+ if completed_steps >= args.max_train_steps:
403
+ break
404
+
405
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
406
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
407
+ model.train()
408
+ logger.info(f"Train loss {loss.item()} , validation loss {validation_loss}")
409
+ if args.wandb and accelerator.is_local_main_process:
410
+ import wandb
411
+ wandb.log({"loss": loss.item(),
412
+ "lr": scheduler.get_last_lr()[0],
413
+ "validation_loss": validation_loss,
414
+ "completed_steps": completed_steps})
415
+
416
+ logger.info("Saving model {}".format(args.model_save_name))
417
+ accelerator.wait_for_everyone()
418
+ unwrapped_model = accelerator.unwrap_model(model)
419
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
420
+
421
+ # Calculating the validation loss over epoch
422
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
423
+
424
+ logger.info("Epoch: {}".format(epoch))
425
+ logger.info("Validation loss: {}".format(validation_loss))
426
+
427
+
428
+ def main():
429
+ parser = get_parser()
430
+ parser.add_argument(
431
+ "--wandb",
432
+ action="store_true",
433
+ help="If true, use W&B logging",
434
+ )
435
+ main_args, _ = parser.parse_known_args()
436
+ config = {"args": main_args}
437
+ if main_args.wandb:
438
+ import wandb
439
+ wandb.init(project="Bart_ELI5")
440
+ train(config=config)
441
+
442
+
443
+ main()
444
+
445
+
446
+
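A hedged inference sketch (the question, context, and generation settings are placeholders, and in practice the fine-tuned weights would be loaded rather than the base checkpoint) showing the "question: ... context: ..." prompt format that tokenize_dataset above prepares for the seq2seq model:

    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")  # load fine-tuned weights in practice

    prompt = "question: why is the sky blue? context: <P> Rayleigh scattering ... <P> ..."
    inputs = tokenizer(prompt, max_length=1024, truncation=True, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"],
                                   min_length=64, max_length=256, num_beams=4)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))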
util/common.py ADDED
@@ -0,0 +1,120 @@
1
+ import re
2
+
3
+ import torch
4
+
5
+ kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
6
+ 'wikidata_info', 'history']
7
+
8
+ kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
9
+ 'end_character', 'title', 'section', 'text']
10
+
11
+
12
+ def clean_question(text):
13
+ result = cleanup_references(text)
14
+ result = result.replace("\n", " ")
15
+ result = re.sub(r"\s\s+", " ", result)
16
+ result = result.replace("[deleted]", "")
17
+ return result.lower().strip()
18
+
19
+
20
+ def cleanup_references(text):
21
+ # URL reference where we need to remove both the link text and URL
22
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal
23
+ # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
24
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
25
+ result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
26
+
27
+ # URL reference where we need to preserve link text but remove URL
28
+ # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
29
+ # At the outbreak of the Civil War, Leyburn left his church and joined the South.
30
+ result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
31
+
32
+ # lastly remove just dangling _URL_[0-9]_ URL references
33
+ result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
34
+ return result
35
+
36
+
37
+ def clean_answer(text):
38
+ result = cleanup_references(text)
39
+ result = result.replace("\n", " ")
40
+ result = re.sub(r"\s\s+", " ", result)
41
+ result = re.sub(r"BULLET::::-", "", result)
42
+ return trim(result.strip())
43
+
44
+
45
+ def trim(text, word_count: int = 100):
46
+ return " ".join(text.split(" ")[:word_count])
47
+
48
+
49
+ def articles_to_paragraphs(examples):
50
+ ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
51
+ for bidx, example in enumerate(examples["text"]):
52
+ last_section = ""
53
+ for idx, p in enumerate(example["paragraph"]):
54
+ if "Section::::" in p:
55
+ last_section = p
56
+ ids.append(examples["wikipedia_id"][bidx])
57
+ titles.append(examples["wikipedia_title"][bidx])
58
+ sections.append(last_section)
59
+ texts.append(p)
60
+ start_ps.append(idx)
61
+ end_ps.append(idx)
62
+ start_cs.append(0)
63
+ end_cs.append(len(p))
64
+
65
+ return {"wikipedia_id": ids, "title": titles,
66
+ "section": sections, "text": texts,
67
+ "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
68
+ "start_character": start_cs,
69
+ "end_character": end_cs
70
+ }
71
+
72
+
73
+ def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
74
+ res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages]
75
+ res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk]
76
+
77
+ # make a KILT data point
78
+ # see https://github.com/facebookresearch/KILT#kilt-data-format
79
+ output = []
80
+ for a in eli5_example["answers"]["text"]:
81
+ output.append({"answer": a})
82
+
83
+ output.append({"provenance": [
84
+ # evidence set for the answer from the KILT ks
85
+ {
86
+ "wikipedia_id": r["wikipedia_id"], # *mandatory*
87
+ "title": r["title"],
88
+ "section": r["section"],
89
+ "start_paragraph_id": r["start_paragraph_id"],
90
+ "start_character": r["start_character"],
91
+ "end_paragraph_id": r["end_paragraph_id"],
92
+ "end_character": r["end_character"],
93
+ "text": r["text"],
94
+ "bleu_score": None, # wrt original evidence
95
+ "meta": None # dataset/task specific
96
+ } for r in res_list
97
+ ]})
98
+ return {"id": eli5_example["q_id"],
99
+ "input": eli5_example["title"],
100
+ "output": output, # each element is an answer or provenance (can have multiple of each)
101
+ "meta": None # dataset/task specific
102
+ }
103
+
104
+
105
+ def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
106
+ query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True,
107
+ return_tensors="pt")
108
+ with torch.no_grad():
109
+ q_reps = question_model(query["input_ids"].to(device),
110
+ query["attention_mask"].to(device)).pooler_output
111
+ return q_reps.cpu().numpy()
112
+
113
+
114
+ def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
115
+ p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
116
+ truncation=True, return_tensors="pt")
117
+ with torch.no_grad():
118
+ a_reps = ctx_model(p["input_ids"].to(device),
119
+ p["attention_mask"].to(device)).pooler_output
120
+ return {"embeddings": a_reps.cpu().numpy()}
util/create_dpr_training_from_dataset.py ADDED
@@ -0,0 +1,103 @@
1
+ import argparse
2
+ import random
3
+ import json
4
+ import re
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from sentence_transformers.util import semantic_search, cos_sim
8
+ from tqdm.auto import tqdm
9
+ from datasets import load_dataset
10
+
11
+ from common import clean_answer, clean_question
12
+
13
+
14
+ def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
15
+ exclude_answer_patterns, similarity_threshold=[0.5, 0.6], k=25, min_count=3):
16
+ hard_negative_ctxs = []
17
+ results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
18
+ score_function=cos_sim)
19
+ # list of dicts
20
+ # [{'corpus_id': 8, 'score': -0.019427383318543434},
21
+ # ...
22
+ # {'corpus_id': 10, 'score': -0.09040290117263794}]
23
+ # hard negatives are the most similar examples; plain negatives are the most dissimilar to embedding_index
24
+ hard_negative_results = results[0][1:k + 1]
25
+ assert len(hard_negative_results) > min_count * 2
26
+ for r in hard_negative_results:
27
+ example = dataset[r["corpus_id"]]
28
+ if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
29
+ for a in example["answers"]["text"]:
30
+ hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
31
+ if len(hard_negative_ctxs) > min_count:
32
+ break
33
+ return hard_negative_ctxs[:min_count]
34
+
35
+
36
+ def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
37
+ exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
38
+ negative_ctxs = []
39
+ random_sample = random.sample(range(len(dataset_embeddings)), k * 20)
40
+ similarities = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[random_sample])[0].tolist()
41
+ for idx, score in enumerate(similarities):
42
+ if score < similarity_threshold:
43
+ example = dataset[random_sample[idx]]
44
+ for a in example["answers"]["text"]:
45
+ negative_ctxs.append({"title": "", "text": clean_answer(a)})
46
+ if len(negative_ctxs) > min_count:
47
+ break
48
+ return negative_ctxs[:min_count]
49
+
50
+
51
+ def generate_dpr_training_file(args):
52
+ embedder = SentenceTransformer(args.embedding_model)
53
+
54
+ eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
55
+ eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
56
+ eli5_test_set = load_dataset("vblagoje/lfqa", split="test")
57
+
58
+ train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
59
+ show_progress_bar=True)
60
+ validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
61
+ show_progress_bar=True)
62
+
63
+ test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
64
+ show_progress_bar=True)
65
+ exclude_answer_patterns = [re.compile("not sure what you"), re.compile("\n\n >")]
66
+ for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
67
+ [eli5_train_set, eli5_validation_set, eli5_test_set],
68
+ [train_set, validation_set, test_set]):
69
+ min_elements = 3
70
+ skip_count = 0
71
+ progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
72
+ with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
73
+ for idx, example in enumerate(dataset):
74
+ negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
75
+ hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
76
+ positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"] if
77
+ not any([p.search(a) for p in exclude_answer_patterns])]
78
+ if not positive_context:
79
+ positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
80
+ if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
81
+ json.dump({"id": example["q_id"],
82
+ "question": clean_question(example["title"]),
83
+ "positive_ctxs": positive_context[:min_elements],
84
+ "negative_ctxs": negative_ctxs[:min_elements],
85
+ "hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
86
+ fp.write("\n")
87
+ else:
88
+ skip_count += 1
89
+ progress_bar.update(1)
90
+
91
+ print(f"Skipped {skip_count} questions")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
96
+ parser.add_argument(
97
+ "--embedding_model",
98
+ default="all-mpnet-base-v2",
99
+ help="Embedding model to use for question encoding and semantic search",
100
+ )
101
+
102
+ main_args, _ = parser.parse_known_args()
103
+ generate_dpr_training_file(main_args)
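For reference, each line of the eli5-dpr-*.jsonl files written above is one JSON object in the DPR retriever input format; a sketch of a single record with abbreviated placeholder values:

    record = {
        "id": "abc123",                                         # ELI5 question id (placeholder)
        "question": "why is the sky blue?",
        "positive_ctxs": [{"title": "", "text": "..."}],        # highest-scoring answers
        "negative_ctxs": [{"title": "", "text": "..."}],        # randomly sampled dissimilar answers
        "hard_negative_ctxs": [{"title": "", "text": "..."}],   # similar-but-wrong answers
    }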
util/create_dpr_training_from_faiss.py ADDED
@@ -0,0 +1,144 @@
+ import argparse
+ import json
+
+ import torch
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from transformers import DPRQuestionEncoder
+
+ from common import embed_questions, clean_question, articles_to_paragraphs, kilt_wikipedia_columns
+ from common import kilt_wikipedia_paragraph_columns as columns
+
+
+ def generate_dpr_training_file(args):
+     n_negatives = 7
+     min_chars_per_passage = 200
+
+     def query_index(question, topk=(n_negatives * args.n_positives) * 2):
+         question_embedding = embed_questions(question_model, question_tokenizer, [question])
+         scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
+
+         retrieved_examples = []
+         r = list(zip(wiki_passages[k] for k in columns))
+         for i in range(topk):
+             retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
+
+         return retrieved_examples
+
+     def find_positive_and_hard_negative_ctxs(dataset_index: int, n_positive=1, device="cuda:0"):
+         positive_context_list = []
+         hard_negative_context_list = []
+         example = dataset[dataset_index]
+         question = clean_question(example['title'])
+         passages = query_index(question)
+         passages = [dict([(k, p[k]) for k in columns]) for p in passages]
+         q_passage_pairs = [[question, f"{p['title']} {p['text']}" if args.use_title else p["text"]] for p in passages]
+
+         features = ce_tokenizer(q_passage_pairs, padding="max_length", max_length=256, truncation=True,
+                                 return_tensors="pt")
+         with torch.no_grad():
+             passage_scores = ce_model(features["input_ids"].to(device),
+                                       features["attention_mask"].to(device)).logits
+
+         for p_idx, p in enumerate(passages):
+             p["score"] = passage_scores[p_idx].item()
+
+         # order by cross-encoder scores
+         def score_passage(item):
+             return item["score"]
+
+         # pick the most relevant passages as positives
+         best_passage_list = sorted(passages, key=score_passage, reverse=True)
+         for idx, item in enumerate(best_passage_list):
+             if idx < n_positive:
+                 positive_context_list.append({"title": item["title"], "text": item["text"]})
+             else:
+                 break
+
+         # least relevant passages become hard negatives
+         worst_passage_list = sorted(passages, key=score_passage, reverse=False)
+         for idx, hard_negative in enumerate(worst_passage_list):
+             if idx < n_negatives * n_positive:
+                 hard_negative_context_list.append({"title": hard_negative["title"], "text": hard_negative["text"]})
+             else:
+                 break
+         assert len(positive_context_list) * n_negatives == len(hard_negative_context_list)
+         return positive_context_list, hard_negative_context_list
+
+     device = ("cuda" if torch.cuda.is_available() else "cpu")
+
+     question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
+     question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
+     _ = question_model.eval()
+
+     ce_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2').to(device)
+     ce_tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2')
+     _ = ce_model.eval()
+
+     kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
+
+     kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
+                                                    remove_columns=kilt_wikipedia_columns,
+                                                    batch_size=512,
+                                                    cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
+                                                    desc="Expanding wiki articles into paragraphs")
+
+     # use paragraphs that are not simple fragments or very short sentences
+     # Wikipedia Faiss index needs to fit into a 16 GB GPU
+     kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
+         lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
+
+     kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
+
+     eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
+     eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
+     eli5_test_set = load_dataset("vblagoje/lfqa", split="test")
+
+     for dataset_name, dataset in zip(["train", "validation", "test"], [eli5_train_set,
+                                                                        eli5_validation_set,
+                                                                        eli5_test_set]):
+
+         progress_bar = tqdm(range(len(dataset)), desc=f"Creating DPR formatted {dataset_name} file")
+         with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
+             for idx, example in enumerate(dataset):
+                 negative_start_idx = 0
+                 positive_context, hard_negative_ctxs = find_positive_and_hard_negative_ctxs(idx, args.n_positives,
+                                                                                             device)
+                 for pc in positive_context:
+                     hnc = hard_negative_ctxs[negative_start_idx:negative_start_idx + n_negatives]
+                     json.dump({"id": example["q_id"],
+                                "question": clean_question(example["title"]),
+                                "positive_ctxs": [pc],
+                                "hard_negative_ctxs": hnc}, fp)
+                     fp.write("\n")
+                     negative_start_idx += n_negatives
+                 progress_bar.update(1)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Creates DPR training file")
+     parser.add_argument(
+         "--use_title",
+         action="store_true",
+         help="If true, use title in addition to passage text for passage embedding",
+     )
+     parser.add_argument(
+         "--n_positives",
+         type=int,
+         default=3,
+         help="Number of positive samples per question",
+     )
+     parser.add_argument(
+         "--question_encoder_name",
+         default="vblagoje/dpr-question_encoder-single-lfqa-base",
+         help="Question encoder to use",
+     )
+
+     parser.add_argument(
+         "--index_file_name",
+         default="../data/kilt_dpr_wikipedia_first.faiss",
+         help="Faiss index with passage embeddings",
+     )
+
+     main_args, _ = parser.parse_known_args()
+     generate_dpr_training_file(main_args)
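
For a quick check of what this script writes: every line of eli5-dpr-<split>.jsonl is one standalone JSON record holding a single positive passage and n_negatives hard negatives for a question. A minimal sanity-check sketch, assuming the train split has already been generated; the file name comes from the open() call above, everything else is illustrative.

import json

# Inspect the first generated DPR training record.
with open("eli5-dpr-train.jsonl") as fp:
    record = json.loads(next(fp))

assert set(record) == {"id", "question", "positive_ctxs", "hard_negative_ctxs"}
assert len(record["positive_ctxs"]) == 1       # one positive passage per record
assert len(record["hard_negative_ctxs"]) == 7  # n_negatives hard negatives per record
print(record["question"], "->", record["positive_ctxs"][0]["title"])
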
util/create_faiss_index.py ADDED
@@ -0,0 +1,67 @@
+ import argparse
+ import os
+
+ import faiss
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, DPRContextEncoder
+
+ from common import articles_to_paragraphs, embed_passages
+
+
+ def create_faiss(args):
+     dims = 128
+     min_chars_per_passage = 200
+     device = ("cuda" if torch.cuda.is_available() else "cpu")
+
+     ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
+     ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
+     _ = ctx_model.eval()
+
+     kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
+     kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
+                               'wikidata_info', 'history']
+
+     kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
+                                                    remove_columns=kilt_wikipedia_columns,
+                                                    batch_size=512,
+                                                    cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
+                                                    desc="Expanding wiki articles into paragraphs")
+
+     # use paragraphs that are not simple fragments or very short sentences
+     # Wikipedia Faiss index needs to fit into a 16 GB GPU
+     kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
+         lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
+
+     if not os.path.isfile(args.index_file_name):
+         def embed_passages_for_retrieval(examples):
+             return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)
+
+         paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
+                                                               batched=True, batch_size=512,
+                                                               cache_file_name="../data/kilt_embedded.arrow",
+                                                               desc="Creating faiss index")
+
+         paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
+         paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
+     else:
+         print(f"Faiss index already exists {args.index_file_name}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")
+
+     parser.add_argument(
+         "--ctx_encoder_name",
+         default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
+         help="Encoding model to use for passage encoding",
+     )
+
+     parser.add_argument(
+         "--index_file_name",
+         default="../data/kilt_dpr_wikipedia.faiss",
+         help="Faiss index file with passage embeddings",
+     )
+
+     main_args, _ = parser.parse_known_args()
+     create_faiss(main_args)
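
The other util/ scripts do not rebuild this index; they re-attach the saved file to the same filtered paragraph dataset and query it, as create_dpr_training_from_faiss.py and eval_generate.py do above and below. A condensed sketch of that consumption path, reusing the names from this file: the path mirrors the default above, kilt_wikipedia_paragraphs is the filtered dataset built in create_faiss, and question_embedding stands for a DPR question vector produced by the question encoder in the other scripts.

# Sketch: attach the saved index and retrieve the top-10 passages for a query vector.
kilt_wikipedia_paragraphs.load_faiss_index("embeddings", "../data/kilt_dpr_wikipedia.faiss", device=0)
scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples(
    "embeddings", question_embedding, k=10)
print(wiki_passages["title"][:3])  # titles of the three closest passages
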
util/eval_generate.py ADDED
@@ -0,0 +1,140 @@
+ import argparse
+ import json
+ import os
+
+ import torch
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder
+
+ from common import articles_to_paragraphs, kilt_wikipedia_columns
+ from common import kilt_wikipedia_paragraph_columns as columns
+
+
+ def eval_generate(args):
+     device = ("cuda" if torch.cuda.is_available() else "cpu")
+     question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
+     question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
+     _ = question_model.eval()
+
+     eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
+     eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
+     _ = eli5_model.eval()
+
+     min_snippet_length = 20
+     topk = 21
+     min_chars_per_passage = 200
+     kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
+     kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
+                                                    remove_columns=kilt_wikipedia_columns,
+                                                    batch_size=256,
+                                                    cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
+                                                    desc="Expanding wiki articles into paragraphs")
+
+     # use paragraphs that are not simple fragments or very short sentences
+     kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
+         lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
+     kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
+
+     def embed_questions_for_retrieval(questions):
+         query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
+         with torch.no_grad():
+             q_reps = question_model(query["input_ids"].to(device),
+                                     query["attention_mask"].to(device)).pooler_output
+         return q_reps.cpu().numpy()
+
+     def query_index(question):
+         question_embedding = embed_questions_for_retrieval([question])
+         scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
+
+         retrieved_examples = []
+         r = list(zip(wiki_passages[k] for k in columns))
+         for i in range(topk):
+             retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
+         return retrieved_examples
+
+     def create_kilt_datapoint(q_id, query, answer, res_list):
+         # make a KILT data point
+         # see https://github.com/facebookresearch/KILT#kilt-data-format
+
+         provenance = [{
+             "wikipedia_id": r["wikipedia_id"],  # *mandatory*
+             "title": r["title"],
+             "section": r["section"],
+             "start_paragraph_id": r["start_paragraph_id"],
+             "start_character": r["start_character"],
+             "end_paragraph_id": r["end_paragraph_id"],
+             "end_character": r["end_character"],
+             "text": r["text"],
+             "bleu_score": None,  # wrt original evidence
+             "meta": None  # dataset/task specific
+         } for r in res_list]
+
+         output = [{"answer": answer, "provenance": provenance}]
+
+         return {"id": q_id,
+                 "input": query,
+                 "output": output,  # each element is an answer or provenance (can have multiple of each)
+                 "meta": None  # dataset/task specific
+                 }
+
+     kilt_output = []
+     with open(args.kilt_input_file, "r") as f:
+         kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
+         progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
+         for idx, item in enumerate(kilt_items):
+             query = item["input"]
+             res_list = query_index(query)
+
+             res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
+             documents = [res["text"] for res in res_list]
+             conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
+
+             query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
+
+             model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
+             generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
+                                                             attention_mask=model_input["attention_mask"].to(device),
+                                                             min_length=50,
+                                                             max_length=250,
+                                                             do_sample=False,
+                                                             early_stopping=True,
+                                                             num_beams=8,
+                                                             temperature=1.0,
+                                                             top_k=None,
+                                                             top_p=None,
+                                                             no_repeat_ngram_size=3,
+                                                             num_return_sequences=1)
+             answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
+                                                  clean_up_tokenization_spaces=True)
+
+             kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
+             kilt_output.append(kilt_example)
+             progress_bar.update(1)
+
+     with open(args.kilt_output_file, "w") as fp:
+         for kilt_example in kilt_output:
+             json.dump(kilt_example, fp)
+             fp.write("\n")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
+     parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
+     parser.add_argument(
+         "--question_encoder_name",
+         default="vblagoje/dpr-question_encoder-single-lfqa-base",
+         help="Question encoder to use",
+     )
+
+     parser.add_argument(
+         "--index_file_name",
+         default="../data/kilt_dpr_wikipedia_first.faiss",
+         help="Faiss index with passage embeddings",
+     )
+
+     args = parser.parse_args()
+
+     assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
+     eval_generate(args)
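
Each line of --kilt_output_file follows the KILT data format referenced in the comment above: the question id, the original input, and one output entry carrying the generated answer plus the provenance passages that conditioned BART. For orientation, the shape of a single record as a Python literal; the keys come from create_kilt_datapoint above, and all values are dummy placeholders.

# Illustrative shape of one output line (placeholder values only).
kilt_record = {
    "id": "dummy_q_id",
    "input": "why do jets leave contrails at high altitude?",
    "output": [{
        "answer": "generated long-form answer ...",
        "provenance": [{
            "wikipedia_id": "12345", "title": "Some article", "section": "Section::::Some section.",
            "start_paragraph_id": 2, "start_character": 0,
            "end_paragraph_id": 2, "end_character": 480,
            "text": "retrieved supporting paragraph ...",
            "bleu_score": None, "meta": None,
        }],
    }],
    "meta": None,
}
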
util/kilt_create_dpr_support_docs.py ADDED
@@ -0,0 +1,109 @@
+ import argparse
+ import json
+ import os
+
+ import faiss
+ import torch
+ from datasets import load_dataset, Dataset
+ from tqdm.auto import tqdm
+ from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder
+
+ from common import articles_to_paragraphs, embed_questions, embed_passages, create_kilt_datapoint, \
+     kilt_wikipedia_columns
+ from common import kilt_wikipedia_paragraph_columns as columns
+
+
+ def generate_support_docs(args):
+     dims = 128
+     min_chars_per_passage = 200
+     device = ("cuda" if torch.cuda.is_available() else "cpu")
+     lfqa = load_dataset("vblagoje/lfqa")
+
+     ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
+     ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
+     _ = ctx_model.eval()
+
+     question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
+     question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
+     _ = question_model.eval()
+
+     kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
+
+     kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
+                                                    remove_columns=kilt_wikipedia_columns,
+                                                    batch_size=512,
+                                                    cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
+                                                    desc="Expanding wiki articles into paragraphs")
+
+     # use paragraphs that are not simple fragments or very short sentences
+     # Wikipedia Faiss index needs to fit into a 16 GB GPU
+     kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
+         lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
+
+     def query_index(question, topk=7):
+         topk = topk * 3  # grab 3x results and filter for word count
+         question_embedding = embed_questions(question_model, question_tokenizer, [question])
+         scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
+
+         retrieved_examples = []
+         r = list(zip(wiki_passages[k] for k in columns))
+         for i in range(topk):
+             retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
+
+         return retrieved_examples
+
+     def create_support_doc(dataset: Dataset, output_filename: str):
+         progress_bar = tqdm(range(len(dataset)), desc="Creating supporting docs")
+
+         with open(output_filename, "w") as fp:
+             for example in dataset:
+                 wiki_passages = query_index(example["title"])
+                 kilt_dp = create_kilt_datapoint(example, columns, wiki_passages)
+                 json.dump(kilt_dp, fp)
+                 fp.write("\n")
+                 progress_bar.update(1)
+
+     if not os.path.isfile(args.index_file_name):
+         def embed_passages_for_retrieval(examples):
+             return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)
+
+         paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
+                                                               batched=True, batch_size=512,
+                                                               cache_file_name=args.encoded_kilt_file_name,
+                                                               desc="Creating faiss index")
+
+         paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
+         paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
+
+     kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
+     create_support_doc(lfqa["train"], "lfqa_dpr_train_precomputed_dense_docs.json")
+     create_support_doc(lfqa["validation"], "lfqa_dpr_validation_precomputed_dense_docs.json")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Creates support docs for seq2seq model training")
+     parser.add_argument(
+         "--ctx_encoder_name",
+         default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
+         help="Context encoder to use for passage encoding",
+     )
+     parser.add_argument(
+         "--question_encoder_name",
+         default="vblagoje/dpr-question_encoder-single-lfqa-base",
+         help="Question encoder to use",
+     )
+
+     parser.add_argument(
+         "--index_file_name",
+         default="../data/kilt_dpr_wikipedia_first.faiss",
+         help="Faiss index with passage embeddings",
+     )
+
+     parser.add_argument(
+         "--encoded_kilt_file_name",
+         default="../data/kilt_embedded.arrow",
+         help="Encoded KILT file name",
+     )
+
+     main_args, _ = parser.parse_known_args()
+     generate_support_docs(main_args)
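
These precomputed support docs let the downstream seq2seq (BART) training read its retrieval context straight from disk instead of querying the Faiss index on the fly. A rough sketch of how one record could be turned into a generator input, assuming common.create_kilt_datapoint emits the same KILT layout as the inline version in eval_generate.py and reusing that script's "question: ... context: <P> ..." formatting; support_doc stands for one parsed line of lfqa_dpr_train_precomputed_dense_docs.json.

# Sketch: build a seq2seq training input from one precomputed support doc.
passages = [p["text"] for p in support_doc["output"][0]["provenance"]]
conditioned_doc = "<P> " + " <P> ".join(passages)
seq2seq_input = "question: {} context: {}".format(support_doc["input"], conditioned_doc)
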
util/query_smoke_test.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+
+ from datasets import load_dataset
+
+
+ def main():
+     device = ("cuda" if torch.cuda.is_available() else "cpu")
+     tokenizer = AutoTokenizer.from_pretrained('vblagoje/retribert-base-uncased')
+     model = AutoModel.from_pretrained('vblagoje/retribert-base-uncased').to(device)
+     _ = model.eval()
+
+     index_file_name = "./data/kilt_wikipedia.faiss"
+     kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
+     columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
+                'wikidata_info', 'history']
+
+     min_snippet_length = 20
+     topk = 21
+
+     def articles_to_paragraphs(examples):
+         ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
+         for bidx, example in enumerate(examples["text"]):
+             last_section = ""
+             for idx, p in enumerate(example["paragraph"]):
+                 if "Section::::" in p:
+                     last_section = p
+                 ids.append(examples["wikipedia_id"][bidx])
+                 titles.append(examples["wikipedia_title"][bidx])
+                 sections.append(last_section)
+                 texts.append(p)
+                 start_ps.append(idx)
+                 end_ps.append(idx)
+                 start_cs.append(0)
+                 end_cs.append(len(p))
+
+         return {"wikipedia_id": ids, "title": titles,
+                 "section": sections, "text": texts,
+                 "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
+                 "start_character": start_cs,
+                 "end_character": end_cs
+                 }
+
+     kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
+                                                    remove_columns=columns,
+                                                    batch_size=256, cache_file_name="./wiki_kilt_paragraphs_full.arrow",
+                                                    desc="Expanding wiki articles into paragraphs")
+
+     # use paragraphs that are not simple fragments or very short sentences
+     kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 250)
+     kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
+
+     def embed_questions_for_retrieval(questions):
+         query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
+         with torch.no_grad():
+             q_reps = model.embed_questions(query["input_ids"].to(device),
+                                            query["attention_mask"].to(device)).cpu().type(torch.float)
+         return q_reps.numpy()
+
+     def query_index(question):
+         question_embedding = embed_questions_for_retrieval([question])
+         scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
+         columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
+                    'start_character', 'end_character']
+         retrieved_examples = []
+         r = list(zip(wiki_passages[k] for k in columns))
+         for i in range(topk):
+             retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
+         return retrieved_examples
+
+     questions = ["What causes the contrails (cirrus aviaticus) behind jets at high altitude?",
+                  "Why does water heated to room temperature feel colder than the air around it?"]
+     res_list = query_index(questions[0])
+     res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
+     for res in res_list:
+         print("\n")
+         print(res)
+
+
+ if __name__ == "__main__":
+     main()
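
Each res printed by the loop above is a plain dict with the eight keys listed in the inner columns variable (wikipedia_id, title, text, section, start_paragraph_id, end_paragraph_id, start_character, end_character), so a slightly more readable dump of the smoke-test output could look like the sketch below; this is only an editorial suggestion, not part of the commit.

# Sketch: print a compact summary line per retrieved passage.
for res in res_list:
    print(f'{res["title"]} ({res["wikipedia_id"]}): {res["text"][:120]}...')
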
wikipedia_answer.png ADDED
wikipedia_context.png ADDED