Spaces · Runtime error
Commit b6bff08: Duplicate from h2oai/h2ogpt-chatbot
Co-authored-by: Jonathan McKinney <pseudotensor@users.noreply.huggingface.co>
- .gitattributes +34 -0
- LICENSE +201 -0
- README.md +14 -0
- app.py +1959 -0
- client_test.py +93 -0
- finetune.py +934 -0
- h2o-logo.svg +1 -0
- prompter.py +106 -0
- requirements.txt +48 -0
- stopping.py +139 -0
- utils.py +154 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: H2ogpt Chatbot
emoji: 📚
colorFrom: yellow
colorTo: yellow
sdk: gradio
sdk_version: 3.27.0
app_file: app.py
pinned: false
license: apache-2.0
duplicated_from: h2oai/h2ogpt-chatbot
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
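
Aside: the frontmatter above tells Hugging Face Spaces to run the file named by app_file (app.py) with the Gradio SDK. As a minimal, hypothetical sketch (not part of this commit), a Space entrypoint only needs to build and launch a Gradio interface; the real app.py below does the same thing with a much richer gr.Blocks UI and model loading:

import gradio as gr

def respond(message: str) -> str:
    # Placeholder handler; the actual app routes the prompt to the loaded LLM.
    return message

# Spaces serves whatever interface this file constructs and launches.
demo = gr.Interface(fn=respond, inputs="text", outputs="text")
demo.launch()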
app.py
ADDED
@@ -0,0 +1,1959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import inspect
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import traceback
|
6 |
+
import typing
|
7 |
+
from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
|
8 |
+
|
9 |
+
SEED = 1236
|
10 |
+
set_seed(SEED)
|
11 |
+
|
12 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
13 |
+
from typing import Union
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
import fire
|
18 |
+
import torch
|
19 |
+
from peft import PeftModel
|
20 |
+
from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
|
21 |
+
from accelerate import init_empty_weights, infer_auto_device_map
|
22 |
+
|
23 |
+
from prompter import Prompter
|
24 |
+
|
25 |
+
from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
|
26 |
+
human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
|
27 |
+
from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
|
28 |
+
|
29 |
+
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
30 |
+
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
31 |
+
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
32 |
+
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
33 |
+
admin_pass = os.getenv("ADMIN_PASS")
|
34 |
+
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
35 |
+
raise_generate_gpu_exceptions = True
|
36 |
+
|
37 |
+
eval_extra_columns = ['prompt', 'response', 'score']
|
38 |
+
|
39 |
+
def main(
|
40 |
+
load_8bit: bool = False,
|
41 |
+
load_half: bool = True,
|
42 |
+
infer_devices: bool = True,
|
43 |
+
base_model: str = '',
|
44 |
+
tokenizer_base_model: str = '',
|
45 |
+
lora_weights: str = "",
|
46 |
+
gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
|
47 |
+
|
48 |
+
prompt_type: Union[int, str] = None,
|
49 |
+
# input to generation
|
50 |
+
temperature: float = None,
|
51 |
+
top_p: float = None,
|
52 |
+
top_k: int = None,
|
53 |
+
num_beams: int = None,
|
54 |
+
repetition_penalty: float = None,
|
55 |
+
num_return_sequences: int = None,
|
56 |
+
do_sample: bool = None,
|
57 |
+
max_new_tokens: int = None,
|
58 |
+
min_new_tokens: int = None,
|
59 |
+
early_stopping: Union[bool, str] = None,
|
60 |
+
max_time: float = None,
|
61 |
+
|
62 |
+
llama_type: bool = None,
|
63 |
+
debug: bool = False,
|
64 |
+
save_dir: str = None,
|
65 |
+
share: bool = True,
|
66 |
+
local_files_only: bool = False,
|
67 |
+
resume_download: bool = True,
|
68 |
+
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
69 |
+
|
70 |
+
src_lang: str = "English",
|
71 |
+
tgt_lang: str = "Russian",
|
72 |
+
|
73 |
+
gradio: bool = True,
|
74 |
+
gradio_avoid_processing_markdown: bool = False,
|
75 |
+
chat: bool = True,
|
76 |
+
chat_history: int = 4096, # character length of chat context/history
|
77 |
+
stream_output: bool = True,
|
78 |
+
show_examples: bool = None,
|
79 |
+
verbose: bool = False,
|
80 |
+
h2ocolors: bool = True,
|
81 |
+
height: int = 400,
|
82 |
+
show_lora: bool = True,
|
83 |
+
# set to True to load --base_model after client logs in,
|
84 |
+
# to be able to free GPU memory when model is swapped
|
85 |
+
login_mode_if_model0: bool = False,
|
86 |
+
|
87 |
+
sanitize_user_prompt: bool = True,
|
88 |
+
sanitize_bot_response: bool = True,
|
89 |
+
|
90 |
+
extra_model_options: typing.List[str] = [],
|
91 |
+
extra_lora_options: typing.List[str] = [],
|
92 |
+
|
93 |
+
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
94 |
+
auto_score: bool = True,
|
95 |
+
|
96 |
+
eval_sharegpt_prompts_only: int = 0,
|
97 |
+
eval_sharegpt_prompts_only_seed: int = 1234,
|
98 |
+
eval_sharegpt_as_output: bool = False,
|
99 |
+
):
|
100 |
+
# allow set token directly
|
101 |
+
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
102 |
+
|
103 |
+
if is_public:
|
104 |
+
temperature = 0.4
|
105 |
+
top_p = 0.85
|
106 |
+
top_k = 70
|
107 |
+
do_sample = True
|
108 |
+
if is_low_mem:
|
109 |
+
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
110 |
+
load_8bit = True
|
111 |
+
else:
|
112 |
+
base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
113 |
+
if is_low_mem:
|
114 |
+
load_8bit = True
|
115 |
+
if is_hf:
|
116 |
+
# must override share if in spaces
|
117 |
+
share = False
|
118 |
+
save_dir = os.getenv('SAVE_DIR', save_dir)
|
119 |
+
|
120 |
+
# get defaults
|
121 |
+
model_lower = base_model.lower()
|
122 |
+
if not gradio:
|
123 |
+
# force, else not single response like want to look at
|
124 |
+
stream_output = False
|
125 |
+
# else prompt removal can mess up output
|
126 |
+
chat = False
|
127 |
+
|
128 |
+
placeholder_instruction, placeholder_input, \
|
129 |
+
stream_output, show_examples, \
|
130 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
131 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
132 |
+
repetition_penalty, num_return_sequences, \
|
133 |
+
do_sample, \
|
134 |
+
src_lang, tgt_lang, \
|
135 |
+
examples, \
|
136 |
+
task_info = \
|
137 |
+
get_generate_params(model_lower, chat,
|
138 |
+
stream_output, show_examples,
|
139 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
140 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
141 |
+
repetition_penalty, num_return_sequences,
|
142 |
+
do_sample,
|
143 |
+
)
|
144 |
+
|
145 |
+
if not gradio:
|
146 |
+
if eval_sharegpt_prompts_only > 0:
|
147 |
+
# override default examples with shareGPT ones for human-level eval purposes only
|
148 |
+
eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
|
149 |
+
if not os.path.isfile(eval_filename):
|
150 |
+
os.system(
|
151 |
+
'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
|
152 |
+
import json
|
153 |
+
data = json.load(open(eval_filename, 'rt'))
|
154 |
+
# focus on data that starts with human, else likely chopped from other data
|
155 |
+
turn_start = 0 # odd in general
|
156 |
+
data = [x for x in data if len(x['conversations']) > turn_start + 1 and
|
157 |
+
x['conversations'][turn_start]['from'] == 'human' and
|
158 |
+
x['conversations'][turn_start + 1]['from'] == 'gpt']
|
159 |
+
np.random.seed(eval_sharegpt_prompts_only_seed)
|
160 |
+
example1 = examples[-1] # pick reference example
|
161 |
+
examples = []
|
162 |
+
responses = []
|
163 |
+
for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
|
164 |
+
assert data[i]['conversations'][turn_start]['from'] == 'human'
|
165 |
+
instruction = data[i]['conversations'][turn_start]['value']
|
166 |
+
assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
|
167 |
+
output = data[i]['conversations'][turn_start + 1]['value']
|
168 |
+
examplenew = example1.copy()
|
169 |
+
assert not chat, "No gradio must use chat=False, uses nochat isntruct"
|
170 |
+
examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
|
171 |
+
examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
|
172 |
+
examplenew[eval_func_param_names.index('context')] = '' # no context
|
173 |
+
examples.append(examplenew)
|
174 |
+
responses.append(output)
|
175 |
+
|
176 |
+
num_examples = len(examples)
|
177 |
+
scoring_path = 'scoring'
|
178 |
+
os.makedirs(scoring_path, exist_ok=True)
|
179 |
+
if eval_sharegpt_as_output:
|
180 |
+
used_base_model = 'gpt35'
|
181 |
+
used_lora_weights = ''
|
182 |
+
else:
|
183 |
+
used_base_model = str(base_model.split('/')[-1])
|
184 |
+
used_lora_weights = str(lora_weights.split('/')[-1])
|
185 |
+
eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
186 |
+
eval_sharegpt_prompts_only_seed,
|
187 |
+
eval_sharegpt_as_output,
|
188 |
+
used_base_model,
|
189 |
+
used_lora_weights)
|
190 |
+
eval_filename = os.path.join(scoring_path, eval_filename)
|
191 |
+
|
192 |
+
with torch.device("cuda"):
|
193 |
+
# ensure was set right above before examples generated
|
194 |
+
assert not stream_output, "stream_output=True does not make sense with example loop"
|
195 |
+
import time
|
196 |
+
from functools import partial
|
197 |
+
|
198 |
+
# get score model
|
199 |
+
smodel, stokenizer, sdevice = get_score_model(**locals())
|
200 |
+
|
201 |
+
if not eval_sharegpt_as_output:
|
202 |
+
model, tokenizer, device = get_model(**locals())
|
203 |
+
model_state = [model, tokenizer, device, base_model]
|
204 |
+
fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir)
|
205 |
+
else:
|
206 |
+
assert eval_sharegpt_prompts_only > 0
|
207 |
+
|
208 |
+
def get_response(*args, exi=0):
|
209 |
+
# assumes same ordering of examples and responses
|
210 |
+
yield responses[exi]
|
211 |
+
|
212 |
+
fun = get_response
|
213 |
+
t0 = time.time()
|
214 |
+
score_dump = []
|
215 |
+
|
216 |
+
import matplotlib.pyplot as plt
|
217 |
+
|
218 |
+
for exi, ex in enumerate(examples):
|
219 |
+
instruction = ex[eval_func_param_names.index('instruction_nochat')]
|
220 |
+
iinput = ex[eval_func_param_names.index('iinput_nochat')]
|
221 |
+
context = ex[eval_func_param_names.index('context')]
|
222 |
+
clear_torch_cache()
|
223 |
+
print("")
|
224 |
+
print("START" + "=" * 100)
|
225 |
+
print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
|
226 |
+
print("-" * 105)
|
227 |
+
# fun yields as generator, so have to iterate over it
|
228 |
+
# Also means likely do NOT want --stream_output=True, else would show all generations
|
229 |
+
for res in fun(*tuple(ex), exi=exi):
|
230 |
+
print(res)
|
231 |
+
if smodel:
|
232 |
+
score_with_prompt = False
|
233 |
+
if score_with_prompt:
|
234 |
+
data_point = dict(instruction=instruction, input=iinput, context=context)
|
235 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
236 |
+
prompt = prompter.generate_prompt(data_point)
|
237 |
+
else:
|
238 |
+
# just raw input and output
|
239 |
+
assert iinput in [None, ''] # should be no iinput
|
240 |
+
assert context in [None, ''] # should be no context
|
241 |
+
prompt = instruction
|
242 |
+
cutoff_len = 768 if is_low_mem else 2048
|
243 |
+
inputs = stokenizer(prompt, res,
|
244 |
+
return_tensors="pt",
|
245 |
+
truncation=True,
|
246 |
+
max_length=cutoff_len)
|
247 |
+
try:
|
248 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
249 |
+
except torch.cuda.OutOfMemoryError as e:
|
250 |
+
print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
|
251 |
+
traceback.print_exc()
|
252 |
+
score = 0.0
|
253 |
+
clear_torch_cache()
|
254 |
+
except (Exception, RuntimeError) as e:
|
255 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
256 |
+
'expected scalar type Half but found Float' in str(e) or \
|
257 |
+
'probability tensor contains either' in str(e) or \
|
258 |
+
'cublasLt ran into an error!' in str(e):
|
259 |
+
print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
260 |
+
flush=True)
|
261 |
+
traceback.print_exc()
|
262 |
+
score = 0.0
|
263 |
+
clear_torch_cache()
|
264 |
+
else:
|
265 |
+
raise
|
266 |
+
print("SCORE %s: %s" % (exi, score), flush=True)
|
267 |
+
score_dump.append(ex + [prompt, res, score])
|
268 |
+
# dump every score in case abort
|
269 |
+
df_scores = pd.DataFrame(score_dump,
|
270 |
+
columns=eval_func_param_names + eval_extra_columns)
|
271 |
+
df_scores.to_parquet(eval_filename, index=False)
|
272 |
+
# plot histogram so far
|
273 |
+
plt.figure(figsize=(10, 10))
|
274 |
+
plt.hist(df_scores['score'], bins=20)
|
275 |
+
score_avg = np.mean(df_scores['score'])
|
276 |
+
score_median = np.median(df_scores['score'])
|
277 |
+
plt.title("Score avg: %s median: %s" % (score_avg, score_median))
|
278 |
+
plt.savefig(eval_filename.replace('.parquet', '.png'))
|
279 |
+
plt.close()
|
280 |
+
|
281 |
+
print("END" + "=" * 102)
|
282 |
+
print("")
|
283 |
+
t2 = time.time()
|
284 |
+
print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
|
285 |
+
t1 = time.time()
|
286 |
+
print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
|
287 |
+
return eval_filename
|
288 |
+
|
289 |
+
if gradio:
|
290 |
+
go_gradio(**locals())
|
291 |
+
|
292 |
+
|
293 |
+
def get_device():
|
294 |
+
if torch.cuda.is_available():
|
295 |
+
device = "cuda"
|
296 |
+
else:
|
297 |
+
raise RuntimeError("only cuda supported")
|
298 |
+
|
299 |
+
return device
|
300 |
+
|
301 |
+
|
302 |
+
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
303 |
+
gpu_id=0,
|
304 |
+
use_auth_token=False):
|
305 |
+
"""
|
306 |
+
Ensure model gets on correct device
|
307 |
+
:param base_model:
|
308 |
+
:param model_loader:
|
309 |
+
:param load_half:
|
310 |
+
:param model_kwargs:
|
311 |
+
:param reward_type:
|
312 |
+
:param gpu_id:
|
313 |
+
:param use_auth_token:
|
314 |
+
:return:
|
315 |
+
"""
|
316 |
+
with init_empty_weights():
|
317 |
+
from transformers import AutoConfig
|
318 |
+
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
|
319 |
+
model = AutoModel.from_config(
|
320 |
+
config,
|
321 |
+
)
|
322 |
+
|
323 |
+
# NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
|
324 |
+
# NOTE: Some models require avoiding sharding some layers,
|
325 |
+
# then would pass no_split_module_classes and give list of those layers.
|
326 |
+
device_map = infer_auto_device_map(
|
327 |
+
model,
|
328 |
+
dtype=torch.float16 if load_half else torch.float32,
|
329 |
+
)
|
330 |
+
if hasattr(model, 'model'):
|
331 |
+
device_map_model = infer_auto_device_map(
|
332 |
+
model.model,
|
333 |
+
dtype=torch.float16 if load_half else torch.float32,
|
334 |
+
)
|
335 |
+
device_map.update(device_map_model)
|
336 |
+
print('device_map: %s' % device_map, flush=True)
|
337 |
+
|
338 |
+
if gpu_id >= 0:
|
339 |
+
# FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
|
340 |
+
# So avoid for now, just put on first GPU, unless score_model, put on last
|
341 |
+
n_gpus = torch.cuda.device_count()
|
342 |
+
if reward_type:
|
343 |
+
device_map = {'': n_gpus - 1}
|
344 |
+
else:
|
345 |
+
device_map = {'': min(n_gpus - 1, gpu_id)}
|
346 |
+
|
347 |
+
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
348 |
+
model_kwargs['device_map'] = device_map
|
349 |
+
|
350 |
+
if load_in_8bit or not load_half:
|
351 |
+
model = model_loader.from_pretrained(
|
352 |
+
base_model,
|
353 |
+
**model_kwargs,
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
model = model_loader.from_pretrained(
|
357 |
+
base_model,
|
358 |
+
**model_kwargs,
|
359 |
+
).half()
|
360 |
+
return model
|
361 |
+
|
362 |
+
|
363 |
+
def get_model(
|
364 |
+
load_8bit: bool = False,
|
365 |
+
load_half: bool = True,
|
366 |
+
infer_devices: bool = True,
|
367 |
+
base_model: str = '',
|
368 |
+
tokenizer_base_model: str = '',
|
369 |
+
lora_weights: str = "",
|
370 |
+
gpu_id: int = 0,
|
371 |
+
|
372 |
+
llama_type: bool = None,
|
373 |
+
reward_type: bool = None,
|
374 |
+
local_files_only: bool = False,
|
375 |
+
resume_download: bool = True,
|
376 |
+
use_auth_token: Union[str, bool] = False,
|
377 |
+
compile: bool = True,
|
378 |
+
**kwargs,
|
379 |
+
):
|
380 |
+
"""
|
381 |
+
|
382 |
+
:param load_8bit: load model in 8-bit, not supported by all models
|
383 |
+
:param load_half: load model in 16-bit
|
384 |
+
:param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
385 |
+
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
386 |
+
So it is not the default
|
387 |
+
:param base_model: name/path of base model
|
388 |
+
:param tokenizer_base_model: name/path of tokenizer
|
389 |
+
:param lora_weights: name/path
|
390 |
+
:param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
|
391 |
+
:param llama_type: whether LLaMa type model
|
392 |
+
:param reward_type: reward type model for sequence classification
|
393 |
+
:param local_files_only: use local files instead of from HF
|
394 |
+
:param resume_download: resume downloads from HF
|
395 |
+
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
396 |
+
:parm compile: whether to compile torch model
|
397 |
+
:param kwargs:
|
398 |
+
:return:
|
399 |
+
"""
|
400 |
+
print("Get %s model" % base_model, flush=True)
|
401 |
+
if lora_weights is not None and lora_weights.strip():
|
402 |
+
print("Get %s lora weights" % lora_weights, flush=True)
|
403 |
+
device = get_device()
|
404 |
+
|
405 |
+
if 'gpt2' in base_model.lower():
|
406 |
+
# RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
|
407 |
+
load_8bit = False
|
408 |
+
|
409 |
+
assert base_model.strip(), (
|
410 |
+
"Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
|
411 |
+
)
|
412 |
+
llama_type = llama_type or "llama" in base_model
|
413 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
|
414 |
+
if not tokenizer_base_model:
|
415 |
+
tokenizer_base_model = base_model
|
416 |
+
|
417 |
+
if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
418 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
419 |
+
local_files_only=local_files_only,
|
420 |
+
resume_download=resume_download,
|
421 |
+
use_auth_token=use_auth_token,
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
tokenizer = tokenizer_loader
|
425 |
+
|
426 |
+
if isinstance(tokenizer, str):
|
427 |
+
# already a pipeline, tokenizer_loader is string for task
|
428 |
+
model = model_loader(tokenizer,
|
429 |
+
model=base_model,
|
430 |
+
device=0 if device == "cuda" else -1,
|
431 |
+
torch_dtype=torch.float16)
|
432 |
+
else:
|
433 |
+
assert device == "cuda", "Unsupported device %s" % device
|
434 |
+
model_kwargs = dict(local_files_only=local_files_only,
|
435 |
+
torch_dtype=torch.float16,
|
436 |
+
resume_download=resume_download,
|
437 |
+
use_auth_token=use_auth_token)
|
438 |
+
if 'mbart-' not in base_model.lower():
|
439 |
+
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
440 |
+
device_map={"": 0} if load_8bit else "auto",
|
441 |
+
))
|
442 |
+
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
443 |
+
# could put on other GPUs
|
444 |
+
model_kwargs['device_map'] = {"": 0}
|
445 |
+
model_kwargs.pop('torch_dtype', None)
|
446 |
+
|
447 |
+
if not lora_weights:
|
448 |
+
with torch.device("cuda"):
|
449 |
+
if infer_devices:
|
450 |
+
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
451 |
+
gpu_id=gpu_id, use_auth_token=use_auth_token)
|
452 |
+
else:
|
453 |
+
if load_half and not load_8bit:
|
454 |
+
model = model_loader.from_pretrained(
|
455 |
+
base_model,
|
456 |
+
**model_kwargs).half()
|
457 |
+
else:
|
458 |
+
model = model_loader.from_pretrained(
|
459 |
+
base_model,
|
460 |
+
**model_kwargs)
|
461 |
+
elif load_8bit:
|
462 |
+
model = model_loader.from_pretrained(
|
463 |
+
base_model,
|
464 |
+
**model_kwargs
|
465 |
+
)
|
466 |
+
model = PeftModel.from_pretrained(
|
467 |
+
model,
|
468 |
+
lora_weights,
|
469 |
+
torch_dtype=torch.float16,
|
470 |
+
local_files_only=local_files_only,
|
471 |
+
resume_download=resume_download,
|
472 |
+
use_auth_token=use_auth_token,
|
473 |
+
device_map={"": 0}, # seems to be required
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
with torch.device("cuda"):
|
477 |
+
model = model_loader.from_pretrained(
|
478 |
+
base_model,
|
479 |
+
**model_kwargs
|
480 |
+
)
|
481 |
+
model = PeftModel.from_pretrained(
|
482 |
+
model,
|
483 |
+
lora_weights,
|
484 |
+
torch_dtype=torch.float16,
|
485 |
+
local_files_only=local_files_only,
|
486 |
+
resume_download=resume_download,
|
487 |
+
use_auth_token=use_auth_token,
|
488 |
+
device_map="auto",
|
489 |
+
)
|
490 |
+
if load_half:
|
491 |
+
model.half()
|
492 |
+
|
493 |
+
# unwind broken decapoda-research config
|
494 |
+
if llama_type:
|
495 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
496 |
+
model.config.bos_token_id = 1
|
497 |
+
model.config.eos_token_id = 2
|
498 |
+
if 'gpt2' in base_model.lower():
|
499 |
+
# add special tokens that otherwise all share the same id
|
500 |
+
tokenizer.add_special_tokens({'bos_token': '<bos>',
|
501 |
+
'eos_token': '<eos>',
|
502 |
+
'pad_token': '<pad>'})
|
503 |
+
|
504 |
+
if not isinstance(tokenizer, str):
|
505 |
+
model.eval()
|
506 |
+
if torch.__version__ >= "2" and sys.platform != "win32" and compile:
|
507 |
+
model = torch.compile(model)
|
508 |
+
|
509 |
+
return model, tokenizer, device
|
510 |
+
|
511 |
+
|
512 |
+
def get_score_model(**kwargs):
|
513 |
+
# score model
|
514 |
+
if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
|
515 |
+
score_all_kwargs = kwargs.copy()
|
516 |
+
score_all_kwargs['load_8bit'] = False
|
517 |
+
score_all_kwargs['load_half'] = False
|
518 |
+
score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
|
519 |
+
score_all_kwargs['tokenizer_base_model'] = ''
|
520 |
+
score_all_kwargs['lora_weights'] = ''
|
521 |
+
score_all_kwargs['llama_type'] = False
|
522 |
+
score_all_kwargs['compile'] = False
|
523 |
+
smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
|
524 |
+
else:
|
525 |
+
smodel, stokenizer, sdevice = None, None, None
|
526 |
+
return smodel, stokenizer, sdevice
|
527 |
+
|
528 |
+
|
529 |
+
def go_gradio(**kwargs):
|
530 |
+
# get default model
|
531 |
+
all_kwargs = kwargs.copy()
|
532 |
+
all_kwargs.update(locals())
|
533 |
+
if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
|
534 |
+
model0, tokenizer0, device = get_model(**all_kwargs)
|
535 |
+
else:
|
536 |
+
# if empty model, then don't load anything, just get gradio up
|
537 |
+
model0, tokenizer0, device = None, None, None
|
538 |
+
model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
|
539 |
+
|
540 |
+
# get score model
|
541 |
+
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
542 |
+
|
543 |
+
if 'mbart-' in kwargs['model_lower']:
|
544 |
+
instruction_label_nochat = "Text to translate"
|
545 |
+
else:
|
546 |
+
instruction_label_nochat = "Instruction"
|
547 |
+
instruction_label = "You (Shift-Enter or push Submit to send message)"
|
548 |
+
|
549 |
+
title = 'h2oGPT'
|
550 |
+
if kwargs['verbose']:
|
551 |
+
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
552 |
+
For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
|
553 |
+
Command: {str(' '.join(sys.argv))}
|
554 |
+
Hash: {get_githash()}
|
555 |
+
"""
|
556 |
+
else:
|
557 |
+
description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
|
558 |
+
if is_public:
|
559 |
+
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
|
560 |
+
if kwargs['load_8bit']:
|
561 |
+
description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
|
562 |
+
description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
|
563 |
+
description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
|
564 |
+
|
565 |
+
if kwargs['verbose']:
|
566 |
+
task_info_md = f"""
|
567 |
+
### Task: {kwargs['task_info']}"""
|
568 |
+
else:
|
569 |
+
task_info_md = ''
|
570 |
+
|
571 |
+
css_code = """footer {visibility: hidden;}
|
572 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
573 |
+
body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
574 |
+
|
575 |
+
from gradio.themes.utils import Color, colors, fonts, sizes
|
576 |
+
if kwargs['h2ocolors']:
|
577 |
+
h2o_yellow = Color(
|
578 |
+
name="yellow",
|
579 |
+
c50="#fffef2",
|
580 |
+
c100="#fff9e6",
|
581 |
+
c200="#ffecb3",
|
582 |
+
c300="#ffe28c",
|
583 |
+
c400="#ffd659",
|
584 |
+
c500="#fec925",
|
585 |
+
c600="#e6ac00",
|
586 |
+
c700="#bf8f00",
|
587 |
+
c800="#a67c00",
|
588 |
+
c900="#664d00",
|
589 |
+
c950="#403000",
|
590 |
+
)
|
591 |
+
h2o_gray = Color(
|
592 |
+
name="gray",
|
593 |
+
c50="#f2f2f2",
|
594 |
+
c100="#e5e5e5",
|
595 |
+
c200="#cccccc",
|
596 |
+
c300="#b2b2b2",
|
597 |
+
c400="#999999",
|
598 |
+
c500="#7f7f7f",
|
599 |
+
c600="#666666",
|
600 |
+
c700="#4c4c4c",
|
601 |
+
c800="#333333",
|
602 |
+
c900="#191919",
|
603 |
+
c950="#0d0d0d",
|
604 |
+
)
|
605 |
+
colors_dict = dict(primary_hue=h2o_yellow,
|
606 |
+
secondary_hue=h2o_yellow,
|
607 |
+
neutral_hue=h2o_gray,
|
608 |
+
spacing_size=sizes.spacing_md,
|
609 |
+
radius_size=sizes.radius_md,
|
610 |
+
text_size=sizes.text_md,
|
611 |
+
)
|
612 |
+
else:
|
613 |
+
colors_dict = dict(primary_hue=colors.indigo,
|
614 |
+
secondary_hue=colors.indigo,
|
615 |
+
neutral_hue=colors.gray,
|
616 |
+
spacing_size=sizes.spacing_md,
|
617 |
+
radius_size=sizes.radius_md,
|
618 |
+
text_size=sizes.text_md,
|
619 |
+
)
|
620 |
+
|
621 |
+
import gradio as gr
|
622 |
+
|
623 |
+
if kwargs['gradio_avoid_processing_markdown']:
|
624 |
+
from gradio_client import utils as client_utils
|
625 |
+
from gradio.components import Chatbot
|
626 |
+
|
627 |
+
# gradio has issue with taking too long to process input/output for markdown etc.
|
628 |
+
# Avoid for now, allow raw html to render, good enough for chatbot.
|
629 |
+
def _postprocess_chat_messages(self, chat_message: str):
|
630 |
+
if chat_message is None:
|
631 |
+
return None
|
632 |
+
elif isinstance(chat_message, (tuple, list)):
|
633 |
+
filepath = chat_message[0]
|
634 |
+
mime_type = client_utils.get_mimetype(filepath)
|
635 |
+
filepath = self.make_temp_copy_if_needed(filepath)
|
636 |
+
return {
|
637 |
+
"name": filepath,
|
638 |
+
"mime_type": mime_type,
|
639 |
+
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
640 |
+
"data": None, # These last two fields are filled in by the frontend
|
641 |
+
"is_file": True,
|
642 |
+
}
|
643 |
+
elif isinstance(chat_message, str):
|
644 |
+
return chat_message
|
645 |
+
else:
|
646 |
+
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
647 |
+
|
648 |
+
Chatbot._postprocess_chat_messages = _postprocess_chat_messages
|
649 |
+
|
650 |
+
demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
|
651 |
+
callback = gr.CSVLogger()
|
652 |
+
# css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
|
653 |
+
# demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
|
654 |
+
|
655 |
+
model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
|
656 |
+
if kwargs['base_model'].strip() not in model_options:
|
657 |
+
lora_options = [kwargs['base_model'].strip()] + model_options
|
658 |
+
lora_options = kwargs['extra_lora_options']
|
659 |
+
if kwargs['lora_weights'].strip() not in lora_options:
|
660 |
+
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
661 |
+
# always add in no lora case
|
662 |
+
# add fake space so doesn't go away in gradio dropdown
|
663 |
+
no_lora_str = no_model_str = '[None/Remove]'
|
664 |
+
lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
|
665 |
+
# always add in no model case so can free memory
|
666 |
+
# add fake space so doesn't go away in gradio dropdown
|
667 |
+
model_options = [no_model_str] + model_options
|
668 |
+
|
669 |
+
# transcribe, will be detranscribed before use by evaluate()
|
670 |
+
if not kwargs['lora_weights'].strip():
|
671 |
+
kwargs['lora_weights'] = no_lora_str
|
672 |
+
|
673 |
+
if not kwargs['base_model'].strip():
|
674 |
+
kwargs['base_model'] = no_model_str
|
675 |
+
|
676 |
+
# transcribe for gradio
|
677 |
+
kwargs['gpu_id'] = str(kwargs['gpu_id'])
|
678 |
+
|
679 |
+
no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
|
680 |
+
output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
|
681 |
+
'base_model') else no_model_msg
|
682 |
+
output_label0_model2 = no_model_msg
|
683 |
+
|
684 |
+
with demo:
|
685 |
+
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
686 |
+
# https://github.com/gradio-app/gradio/issues/3558
|
687 |
+
model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
|
688 |
+
model_state2 = gr.State([None, None, None, None])
|
689 |
+
model_options_state = gr.State([model_options])
|
690 |
+
lora_options_state = gr.State([lora_options])
|
691 |
+
gr.Markdown(
|
692 |
+
f"""
|
693 |
+
<h1 align="center"> {title}</h1>
|
694 |
+
|
695 |
+
{description}
|
696 |
+
{task_info_md}
|
697 |
+
""")
|
698 |
+
if is_hf:
|
699 |
+
gr.HTML(
|
700 |
+
'''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
|
701 |
+
|
702 |
+
# go button visible if
|
703 |
+
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
704 |
+
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
|
705 |
+
normal_block = gr.Row(visible=not base_wanted)
|
706 |
+
with normal_block:
|
707 |
+
with gr.Tabs():
|
708 |
+
with gr.Row():
|
709 |
+
col_nochat = gr.Column(visible=not kwargs['chat'])
|
710 |
+
with col_nochat: # FIXME: for model comparison, and check rest
|
711 |
+
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
|
712 |
+
instruction_nochat = gr.Textbox(
|
713 |
+
lines=4, label=instruction_label_nochat,
|
714 |
+
placeholder=kwargs['placeholder_instruction'],
|
715 |
+
)
|
716 |
+
iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
|
717 |
+
placeholder=kwargs['placeholder_input'])
|
718 |
+
submit_nochat = gr.Button("Submit")
|
719 |
+
flag_btn_nochat = gr.Button("Flag")
|
720 |
+
if kwargs['score_model']:
|
721 |
+
if not kwargs['auto_score']:
|
722 |
+
with gr.Column():
|
723 |
+
score_btn_nochat = gr.Button("Score last prompt & response")
|
724 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
725 |
+
else:
|
726 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
727 |
+
col_chat = gr.Column(visible=kwargs['chat'])
|
728 |
+
with col_chat:
|
729 |
+
with gr.Row():
|
730 |
+
text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
|
731 |
+
text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
|
732 |
+
height=kwargs['height'] or 400)
|
733 |
+
with gr.Row():
|
734 |
+
with gr.Column(scale=50):
|
735 |
+
instruction = gr.Textbox(
|
736 |
+
lines=4, label=instruction_label,
|
737 |
+
placeholder=kwargs['placeholder_instruction'],
|
738 |
+
)
|
739 |
+
with gr.Row(): # .style(equal_height=False, equal_width=False):
|
740 |
+
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
741 |
+
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
742 |
+
with gr.Row():
|
743 |
+
clear = gr.Button("New Conversation")
|
744 |
+
flag_btn = gr.Button("Flag")
|
745 |
+
if kwargs['score_model']:
|
746 |
+
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
747 |
+
with gr.Column():
|
748 |
+
with gr.Row():
|
749 |
+
score_btn = gr.Button("Score last prompt & response").style(
|
750 |
+
full_width=False, size='sm')
|
751 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
752 |
+
score_res2 = gr.Row(visible=False)
|
753 |
+
with score_res2:
|
754 |
+
score_btn2 = gr.Button("Score last prompt & response 2").style(
|
755 |
+
full_width=False, size='sm')
|
756 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
|
757 |
+
else:
|
758 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
759 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
|
760 |
+
retry = gr.Button("Regenerate")
|
761 |
+
undo = gr.Button("Undo")
|
762 |
+
with gr.TabItem("Input/Output"):
|
763 |
+
with gr.Row():
|
764 |
+
if 'mbart-' in kwargs['model_lower']:
|
765 |
+
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
766 |
+
value=kwargs['src_lang'],
|
767 |
+
label="Input Language")
|
768 |
+
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
769 |
+
value=kwargs['tgt_lang'],
|
770 |
+
label="Output Language")
|
771 |
+
with gr.TabItem("Expert"):
|
772 |
+
with gr.Row():
|
773 |
+
with gr.Column():
|
774 |
+
stream_output = gr.components.Checkbox(label="Stream output",
|
775 |
+
value=kwargs['stream_output'])
|
776 |
+
prompt_type = gr.Dropdown(prompt_types_strings,
|
777 |
+
value=kwargs['prompt_type'], label="Prompt Type",
|
778 |
+
visible=not is_public)
|
779 |
+
prompt_type2 = gr.Dropdown(prompt_types_strings,
|
780 |
+
value=kwargs['prompt_type'], label="Prompt Type Model 2",
|
781 |
+
visible=not is_public and False)
|
782 |
+
do_sample = gr.Checkbox(label="Sample", info="Enable sampler, required for use of temperature, top_p, top_k",
|
783 |
+
value=kwargs['do_sample'])
|
784 |
+
temperature = gr.Slider(minimum=0.01, maximum=3,
|
785 |
+
value=kwargs['temperature'],
|
786 |
+
label="Temperature",
|
787 |
+
info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
|
788 |
+
top_p = gr.Slider(minimum=0, maximum=1,
|
789 |
+
value=kwargs['top_p'], label="Top p",
|
790 |
+
info="Cumulative probability of tokens to sample from")
|
791 |
+
top_k = gr.Slider(
|
792 |
+
minimum=0, maximum=100, step=1,
|
793 |
+
value=kwargs['top_k'], label="Top k",
|
794 |
+
info='Num. tokens to sample from'
|
795 |
+
)
|
796 |
+
max_beams = 8 if not is_low_mem else 2
|
797 |
+
num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
|
798 |
+
value=min(max_beams, kwargs['num_beams']), label="Beams",
|
799 |
+
info="Number of searches for optimal overall probability. "
|
800 |
+
"Uses more GPU memory/compute")
|
801 |
+
max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
|
802 |
+
max_new_tokens = gr.Slider(
|
803 |
+
minimum=1, maximum=max_max_new_tokens, step=1,
|
804 |
+
value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
|
805 |
+
)
|
806 |
+
min_new_tokens = gr.Slider(
|
807 |
+
minimum=0, maximum=max_max_new_tokens, step=1,
|
808 |
+
value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
|
809 |
+
)
|
810 |
+
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
|
811 |
+
value=kwargs['early_stopping'])
|
812 |
+
max_max_time = 60 * 5 if not is_low_mem else 60
|
813 |
+
max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
|
814 |
+
value=min(max_max_time, kwargs['max_time']), label="Max. time",
|
815 |
+
info="Max. time to search optimal output.")
|
816 |
+
repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
|
817 |
+
value=kwargs['repetition_penalty'],
|
818 |
+
label="Repetition Penalty")
|
819 |
+
num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
|
820 |
+
value=kwargs['num_return_sequences'],
|
821 |
+
label="Number Returns", info="Must be <= num_beams",
|
822 |
+
visible=not is_public)
|
823 |
+
iinput = gr.Textbox(lines=4, label="Input",
|
824 |
+
placeholder=kwargs['placeholder_input'],
|
825 |
+
visible=not is_public)
|
826 |
+
context = gr.Textbox(lines=3, label="System Pre-Context",
|
827 |
+
info="Directly pre-appended without prompt processing",
|
828 |
+
visible=not is_public and not kwargs['chat'])
|
829 |
+
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
|
830 |
+
visible=not is_public)
|
831 |
+
|
832 |
+
with gr.TabItem("Models"):
|
833 |
+
load_msg = "Load-Unload Model/LORA" if not is_public \
|
834 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
|
835 |
+
load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
|
836 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
|
837 |
+
compare_checkbox = gr.components.Checkbox(label="Compare Mode",
|
838 |
+
value=False, visible=not is_public)
|
839 |
+
with gr.Row():
|
840 |
+
n_gpus = torch.cuda.device_count()
|
841 |
+
n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
|
842 |
+
with gr.Column():
|
843 |
+
with gr.Row(scale=1):
|
844 |
+
with gr.Column(scale=50):
|
845 |
+
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
|
846 |
+
value=kwargs['base_model'])
|
847 |
+
lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
|
848 |
+
value=kwargs['lora_weights'], visible=kwargs['show_lora'])
|
849 |
+
with gr.Column(scale=1):
|
850 |
+
load_model_button = gr.Button(load_msg)
|
851 |
+
model_load8bit_checkbox = gr.components.Checkbox(
|
852 |
+
label="Load 8-bit [Not all models support]",
|
853 |
+
value=kwargs['load_8bit'])
|
854 |
+
model_infer_devices_checkbox = gr.components.Checkbox(
|
855 |
+
label="Infer Devices [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
|
856 |
+
value=kwargs['infer_devices'])
|
857 |
+
model_gpu = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
|
858 |
+
value=kwargs['gpu_id'])
|
859 |
+
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
|
860 |
+
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
|
861 |
+
visible=kwargs['show_lora'])
|
862 |
+
with gr.Row(scale=1):
|
863 |
+
with gr.Column(scale=50):
|
864 |
+
new_model = gr.Textbox(label="New Model HF name/path")
|
865 |
+
new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
|
866 |
+
with gr.Column(scale=1):
|
867 |
+
add_model_button = gr.Button("Add new model name")
|
868 |
+
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
869 |
+
col_model2 = gr.Column(visible=False)
|
870 |
+
with col_model2:
|
871 |
+
with gr.Row(scale=1):
|
872 |
+
with gr.Column(scale=50):
|
873 |
+
model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
|
874 |
+
value=no_model_str)
|
875 |
+
lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
|
876 |
+
value=no_lora_str,
|
877 |
+
visible=kwargs['show_lora'])
|
878 |
+
with gr.Column(scale=1):
|
879 |
+
load_model_button2 = gr.Button(load_msg2)
|
880 |
+
model_load8bit_checkbox2 = gr.components.Checkbox(
|
881 |
+
label="Load 8-bit 2 [Not all models support]",
|
882 |
+
value=kwargs['load_8bit'])
|
883 |
+
model_infer_devices_checkbox2 = gr.components.Checkbox(
|
884 |
+
label="Infer Devices 2 [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
|
885 |
+
value=kwargs[
|
886 |
+
'infer_devices'])
|
887 |
+
model_gpu2 = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
|
888 |
+
value=kwargs['gpu_id'])
|
889 |
+
# no model/lora loaded ever in model2 by default
|
890 |
+
model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
|
891 |
+
lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
|
892 |
+
visible=kwargs['show_lora'])
|
893 |
+
with gr.TabItem("System"):
|
894 |
+
system_row = gr.Row(visible=not is_public)
|
895 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
896 |
+
admin_btn = gr.Button(value="admin", visible=is_public)
|
897 |
+
with system_row:
|
898 |
+
with gr.Column():
|
899 |
+
system_text = gr.Textbox(label='System Info')
|
900 |
+
system_btn = gr.Button(value='Get System Info')
|
901 |
+
|
902 |
+
zip_btn = gr.Button("Zip")
|
903 |
+
file_output = gr.File()
|
904 |
+
|
905 |
+
# Get flagged data
|
906 |
+
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
907 |
+
zip_btn.click(zip_data1, inputs=None, outputs=file_output)
|
908 |
+
|
909 |
+
def check_admin_pass(x):
|
910 |
+
return gr.update(visible=x == admin_pass)
|
911 |
+
|
912 |
+
admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)
|
913 |
+
|
914 |
+
# Get inputs to evaluate()
|
915 |
+
inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
|
916 |
+
from functools import partial
|
917 |
+
all_kwargs = kwargs.copy()
|
918 |
+
all_kwargs.update(locals())
|
919 |
+
kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
|
920 |
+
fun = partial(evaluate,
|
921 |
+
**kwargs_evaluate)
|
922 |
+
fun2 = partial(evaluate,
|
923 |
+
model_state2,
|
924 |
+
**kwargs_evaluate)
|
925 |
+
|
926 |
+
dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
|
927 |
+
size="sm",
|
928 |
+
)
|
929 |
+
dark_mode_btn.click(
|
930 |
+
None,
|
931 |
+
None,
|
932 |
+
None,
|
933 |
+
_js="""() => {
|
934 |
+
if (document.querySelectorAll('.dark').length) {
|
935 |
+
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
936 |
+
} else {
|
937 |
+
document.querySelector('body').classList.add('dark');
|
938 |
+
}
|
939 |
+
}""",
|
940 |
+
api_name="dark",
|
941 |
+
)
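# The _js callback above runs entirely client-side: it toggles the 'dark' CSS class on the page body,
# so no Python function is needed (fn, inputs, and outputs are all None).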
|
942 |
+
|
943 |
+
# Control chat and non-chat blocks, which are toggled independently via the chat checkbox
|
944 |
+
def col_nochat_fun(x):
|
945 |
+
return gr.Column.update(visible=not x)
|
946 |
+
|
947 |
+
def col_chat_fun(x):
|
948 |
+
return gr.Column.update(visible=x)
|
949 |
+
|
950 |
+
def context_fun(x):
|
951 |
+
return gr.Textbox.update(visible=not x)
|
952 |
+
|
953 |
+
chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox") \
|
954 |
+
.then(col_chat_fun, chat, col_chat) \
|
955 |
+
.then(context_fun, chat, context)
|
956 |
+
|
957 |
+
# examples after submit or any other buttons for chat or no chat
|
958 |
+
if kwargs['examples'] is not None and kwargs['show_examples']:
|
959 |
+
gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
|
960 |
+
|
961 |
+
# Score
|
962 |
+
def score_last_response(*args, nochat=False, model2=False):
|
963 |
+
""" Similar to user() """
|
964 |
+
args_list = list(args)
|
965 |
+
|
966 |
+
max_length_tokenize = 512 if is_low_mem else 2048
|
967 |
+
cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
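# NOTE: the factor of 4 is a rough characters-per-token estimate, so the character-level
# truncation below keeps question/answer within max_length_tokenize after tokenization.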
|
968 |
+
|
969 |
+
if not nochat:
|
970 |
+
history = args_list[-1]
|
971 |
+
if history is None:
|
972 |
+
if not model2:
|
973 |
+
# maybe only doing first model, no need to complain
|
974 |
+
print("Bad history in scoring last response, fix for now", flush=True)
|
975 |
+
history = []
|
976 |
+
if smodel is not None and \
|
977 |
+
stokenizer is not None and \
|
978 |
+
sdevice is not None and \
|
979 |
+
history is not None and len(history) > 0 and \
|
980 |
+
history[-1] is not None and \
|
981 |
+
len(history[-1]) >= 2:
|
982 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
983 |
+
|
984 |
+
question = history[-1][0]
|
985 |
+
|
986 |
+
answer = history[-1][1]
|
987 |
+
else:
|
988 |
+
return 'Response Score: NA'
|
989 |
+
else:
|
990 |
+
answer = args_list[-1]
|
991 |
+
instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
|
992 |
+
question = args_list[instruction_nochat_arg_id]
|
993 |
+
|
994 |
+
if question is None:
|
995 |
+
return 'Response Score: Bad Question'
|
996 |
+
if answer is None:
|
997 |
+
return 'Response Score: Bad Answer'
|
998 |
+
|
999 |
+
question = question[-cutoff_len:]
|
1000 |
+
answer = answer[-cutoff_len:]
|
1001 |
+
|
1002 |
+
inputs = stokenizer(question, answer,
|
1003 |
+
return_tensors="pt",
|
1004 |
+
truncation=True,
|
1005 |
+
max_length=max_length_tokenize).to(smodel.device)
|
1006 |
+
try:
|
1007 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
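# the score model (deberta-style reward model, per the note above) emits a single logit;
# sigmoid maps it to a 0-1 quality score, reported below as a percentage.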
|
1008 |
+
except torch.cuda.OutOfMemoryError as e:
|
1009 |
+
print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
1010 |
+
del inputs
|
1011 |
+
traceback.print_exc()
|
1012 |
+
clear_torch_cache()
|
1013 |
+
return 'Response Score: GPU OOM'
|
1014 |
+
except (Exception, RuntimeError) as e:
|
1015 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
1016 |
+
'expected scalar type Half but found Float' in str(e) or \
|
1017 |
+
'probability tensor contains either' in str(e) or \
|
1018 |
+
'cublasLt ran into an error!' in str(e):
|
1019 |
+
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
|
1020 |
+
flush=True)
|
1021 |
+
traceback.print_exc()
|
1022 |
+
clear_torch_cache()
|
1023 |
+
return 'Response Score: GPU Error'
|
1024 |
+
else:
|
1025 |
+
raise
|
1026 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
1027 |
+
return 'Response Score: {:.1%}'.format(score)
|
1028 |
+
|
1029 |
+
if kwargs['score_model']:
|
1030 |
+
score_args = dict(fn=score_last_response,
|
1031 |
+
inputs=inputs_list + [text_output],
|
1032 |
+
outputs=[score_text],
|
1033 |
+
)
|
1034 |
+
score_args2 = dict(fn=partial(score_last_response, model2=True),
|
1035 |
+
inputs=inputs_list + [text_output2],
|
1036 |
+
outputs=[score_text2],
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
score_args_nochat = dict(fn=partial(score_last_response, nochat=True),
|
1040 |
+
inputs=inputs_list + [text_output_nochat],
|
1041 |
+
outputs=[score_text_nochat],
|
1042 |
+
)
|
1043 |
+
if not kwargs['auto_score']:
|
1044 |
+
score_event = score_btn.click(**score_args, queue=stream_output, api_name='score') \
|
1045 |
+
.then(**score_args2, queue=stream_output, api_name='score2')
|
1046 |
+
score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=stream_output,
|
1047 |
+
api_name='score_nochat')
|
1048 |
+
|
1049 |
+
def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
|
1050 |
+
"""
|
1051 |
+
User that fills history for bot
|
1052 |
+
:param args:
|
1053 |
+
:param undo:
|
1054 |
+
:param sanitize_user_prompt:
|
1055 |
+
:param model2:
|
1056 |
+
:return:
|
1057 |
+
"""
|
1058 |
+
args_list = list(args)
|
1059 |
+
user_message = args_list[0]
|
1060 |
+
input1 = args_list[1]
|
1061 |
+
context1 = args_list[2]
|
1062 |
+
if input1 and not user_message.endswith(':'):
|
1063 |
+
user_message1 = user_message + ":" + input1
|
1064 |
+
elif input1:
|
1065 |
+
user_message1 = user_message + input1
|
1066 |
+
else:
|
1067 |
+
user_message1 = user_message
|
1068 |
+
if sanitize_user_prompt:
|
1069 |
+
from better_profanity import profanity
|
1070 |
+
user_message1 = profanity.censor(user_message1)
|
1071 |
+
|
1072 |
+
history = args_list[-1]
|
1073 |
+
if undo and history:
|
1074 |
+
history.pop()
|
1075 |
+
args_list = args_list[:-1] # FYI, even if unused currently
|
1076 |
+
if history is None:
|
1077 |
+
if not model2:
|
1078 |
+
# no need to complain so often unless model1
|
1079 |
+
print("Bad history, fix for now", flush=True)
|
1080 |
+
history = []
|
1081 |
+
# ensure elements not mixed across models as output,
|
1082 |
+
# even if input is currently same source
|
1083 |
+
history = history.copy()
|
1084 |
+
if undo:
|
1085 |
+
return history
|
1086 |
+
else:
|
1087 |
+
# FIXME: compare, same history for now
|
1088 |
+
return history + [[user_message1, None]]
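# None is a placeholder for the bot's reply; bot() fills history[-1][1] as it streams output.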
|
1089 |
+
|
1090 |
+
def bot(*args, retry=False):
|
1091 |
+
"""
|
1092 |
+
Bot that consumes history to respond to the user's latest input.
|
1093 |
+
The instruction (from input_list) itself is not consumed by the bot.
|
1094 |
+
:param args:
|
1095 |
+
:param retry:
|
1096 |
+
:return:
|
1097 |
+
"""
|
1098 |
+
args_list = list(args).copy()
|
1099 |
+
history = args_list[-1] # model_state is -2
|
1100 |
+
if retry and history:
|
1101 |
+
history.pop()
|
1102 |
+
if not history:
|
1103 |
+
print("No history", flush=True)
|
1104 |
+
return
|
1105 |
+
# ensure output will be unique to models
|
1106 |
+
history = history.copy()
|
1107 |
+
instruction1 = history[-1][0]
|
1108 |
+
context1 = ''
|
1109 |
+
if kwargs['chat_history'] > 0:
|
1110 |
+
prompt_type_arg_id = eval_func_param_names.index('prompt_type')
|
1111 |
+
prompt_type1 = args_list[prompt_type_arg_id]
|
1112 |
+
chat_arg_id = eval_func_param_names.index('chat')
|
1113 |
+
chat1 = args_list[chat_arg_id]
|
1114 |
+
context1 = ''
|
1115 |
+
for histi in range(len(history) - 1):
|
1116 |
+
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
1117 |
+
context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
|
1118 |
+
'<br>', '\n')
|
1119 |
+
if not context1.endswith('\n'):
|
1120 |
+
context1 += '\n'
|
1121 |
+
if context1 and not context1.endswith('\n'):
|
1122 |
+
context1 += '\n' # ensure if terminates abruptly, then human continues on next line
|
1123 |
+
args_list[0] = instruction1 # override original instruction with history from user
|
1124 |
+
# only include desired chat history
|
1125 |
+
args_list[2] = context1[-kwargs['chat_history']:]
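# chat_history acts as a character budget: keep only the most recent characters of the accumulated context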
|
1126 |
+
model_state1 = args_list[-2]
|
1127 |
+
if model_state1[0] is None or model_state1[0] == no_model_str:
|
1128 |
+
return
|
1129 |
+
args_list = args_list[:-2]
|
1130 |
+
fun1 = partial(evaluate,
|
1131 |
+
model_state1,
|
1132 |
+
**kwargs_evaluate)
|
1133 |
+
try:
|
1134 |
+
for output in fun1(*tuple(args_list)):
|
1135 |
+
bot_message = output
|
1136 |
+
history[-1][1] = bot_message
|
1137 |
+
yield history
|
1138 |
+
except StopIteration:
|
1139 |
+
yield history
|
1140 |
+
except RuntimeError as e:
|
1141 |
+
if "generator raised StopIteration" in str(e):
|
1142 |
+
# assume last entry was bad, undo
|
1143 |
+
history.pop()
|
1144 |
+
yield history
|
1145 |
+
raise
|
1146 |
+
except Exception as e:
|
1147 |
+
# put error into user input
|
1148 |
+
history[-1][0] = "Exception: %s" % str(e)
|
1149 |
+
yield history
|
1150 |
+
raise
|
1151 |
+
return
|
1152 |
+
|
1153 |
+
# NORMAL MODEL
|
1154 |
+
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
1155 |
+
inputs=inputs_list + [text_output],
|
1156 |
+
outputs=text_output,
|
1157 |
+
)
|
1158 |
+
bot_args = dict(fn=bot,
|
1159 |
+
inputs=inputs_list + [model_state] + [text_output],
|
1160 |
+
outputs=text_output,
|
1161 |
+
)
|
1162 |
+
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1163 |
+
inputs=inputs_list + [model_state] + [text_output],
|
1164 |
+
outputs=text_output,
|
1165 |
+
)
|
1166 |
+
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
1167 |
+
inputs=inputs_list + [text_output],
|
1168 |
+
outputs=text_output,
|
1169 |
+
)
|
1170 |
+
|
1171 |
+
# MODEL2
|
1172 |
+
user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
|
1173 |
+
inputs=inputs_list + [text_output2],
|
1174 |
+
outputs=text_output2,
|
1175 |
+
)
|
1176 |
+
bot_args2 = dict(fn=bot,
|
1177 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
1178 |
+
outputs=text_output2,
|
1179 |
+
)
|
1180 |
+
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1181 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
1182 |
+
outputs=text_output2,
|
1183 |
+
)
|
1184 |
+
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
|
1185 |
+
inputs=inputs_list + [text_output2],
|
1186 |
+
outputs=text_output2,
|
1187 |
+
)
|
1188 |
+
|
1189 |
+
def clear_instruct():
|
1190 |
+
return gr.Textbox.update(value='')
|
1191 |
+
|
1192 |
+
if kwargs['auto_score']:
|
1193 |
+
# in case 2nd model, consume instruction first, so can clear quickly
|
1194 |
+
# bot doesn't consume the instruction itself, just the history from user, which is why this works
|
1195 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
|
1196 |
+
.then(**user_args2, queue=stream_output, api_name='instruction2') \
|
1197 |
+
.then(clear_instruct, None, instruction) \
|
1198 |
+
.then(**bot_args, api_name='instruction_bot') \
|
1199 |
+
.then(**score_args, api_name='instruction_bot_score') \
|
1200 |
+
.then(**bot_args2, api_name='instruction_bot2') \
|
1201 |
+
.then(**score_args2, api_name='instruction_bot_score2') \
|
1202 |
+
.then(clear_torch_cache)
|
1203 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
|
1204 |
+
.then(**user_args2, queue=stream_output, api_name='submit2') \
|
1205 |
+
.then(**bot_args, api_name='submit_bot') \
|
1206 |
+
.then(clear_instruct, None, instruction) \
|
1207 |
+
.then(**score_args, api_name='submit_bot_score') \
|
1208 |
+
.then(**bot_args2, api_name='submit_bot2') \
|
1209 |
+
.then(**score_args2, api_name='submit_bot_score2') \
|
1210 |
+
.then(clear_torch_cache)
|
1211 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
|
1212 |
+
.then(**user_args2, queue=stream_output, api_name='retry2') \
|
1213 |
+
.then(clear_instruct, None, instruction) \
|
1214 |
+
.then(**retry_bot_args, api_name='retry_bot') \
|
1215 |
+
.then(**score_args, api_name='retry_bot_score') \
|
1216 |
+
.then(**retry_bot_args2, api_name='retry_bot2') \
|
1217 |
+
.then(**score_args2, api_name='retry_bot_score2') \
|
1218 |
+
.then(clear_torch_cache)
|
1219 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
|
1220 |
+
.then(**score_args, api_name='undo_score') \
|
1221 |
+
.then(**undo_user_args2, queue=stream_output, api_name='undo2') \
|
1222 |
+
.then(**score_args2, api_name='undo_score2') \
|
1223 |
+
.then(clear_instruct, None, instruction)
|
1224 |
+
else:
|
1225 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
|
1226 |
+
.then(**user_args2, queue=stream_output, api_name='instruction2') \
|
1227 |
+
.then(clear_instruct, None, instruction) \
|
1228 |
+
.then(**bot_args, api_name='instruction_bot') \
|
1229 |
+
.then(**bot_args2, api_name='instruction_bot2') \
|
1230 |
+
.then(clear_torch_cache)
|
1231 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
|
1232 |
+
.then(**user_args2, queue=stream_output, api_name='submit2') \
|
1233 |
+
.then(clear_instruct, None, instruction) \
|
1234 |
+
.then(**bot_args, api_name='submit_bot') \
|
1235 |
+
.then(**bot_args2, api_name='submit_bot2') \
|
1236 |
+
.then(clear_torch_cache)
|
1237 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
|
1238 |
+
.then(**user_args2, queue=stream_output, api_name='retry2') \
|
1239 |
+
.then(clear_instruct, None, instruction) \
|
1240 |
+
.then(**retry_bot_args, api_name='retry_bot') \
|
1241 |
+
.then(**retry_bot_args2, api_name='retry_bot2') \
|
1242 |
+
.then(clear_torch_cache)
|
1243 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
|
1244 |
+
.then(**undo_user_args2, queue=stream_output, api_name='undo2')
|
1245 |
+
|
1246 |
+
# does both models
|
1247 |
+
clear.click(lambda: None, None, text_output, queue=False, api_name='clear') \
|
1248 |
+
.then(lambda: None, None, text_output2, queue=False, api_name='clear2')
|
1249 |
+
# FIXME: compare
|
1250 |
+
submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
|
1251 |
+
outputs=text_output_nochat, api_name='submit_nochat') \
|
1252 |
+
.then(**score_args_nochat, api_name='instruction_bot_score_nochat') \
|
1253 |
+
.then(clear_torch_cache)
|
1254 |
+
|
1255 |
+
def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
|
1256 |
+
# ensure old model removed from GPU memory
|
1257 |
+
if kwargs['debug']:
|
1258 |
+
print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1259 |
+
|
1260 |
+
if isinstance(model_state_old[0], str) and model0 is not None:
|
1261 |
+
# best we can do: move the model loaded at startup to CPU
|
1262 |
+
model0.cpu()
|
1263 |
+
|
1264 |
+
if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
|
1265 |
+
try:
|
1266 |
+
model_state_old[0].cpu()
|
1267 |
+
except Exception as e:
|
1268 |
+
# sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
|
1269 |
+
print("Unable to put model on CPU: %s" % str(e), flush=True)
|
1270 |
+
del model_state_old[0]
|
1271 |
+
model_state_old[0] = None
|
1272 |
+
|
1273 |
+
if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
|
1274 |
+
del model_state_old[1]
|
1275 |
+
model_state_old[1] = None
|
1276 |
+
|
1277 |
+
clear_torch_cache()
|
1278 |
+
if kwargs['debug']:
|
1279 |
+
print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1280 |
+
|
1281 |
+
if model_name is None or model_name == no_model_str:
|
1282 |
+
# no-op if no model, just free memory
|
1283 |
+
# no detranscribe needed for model, never go into evaluate
|
1284 |
+
lora_weights = no_lora_str
|
1285 |
+
return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
|
1286 |
+
|
1287 |
+
all_kwargs1 = all_kwargs.copy()
|
1288 |
+
all_kwargs1['base_model'] = model_name.strip()
|
1289 |
+
all_kwargs1['load_8bit'] = load_8bit
|
1290 |
+
all_kwargs1['infer_devices'] = infer_devices
|
1291 |
+
all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
|
1292 |
+
model_lower = model_name.strip().lower()
|
1293 |
+
if model_lower in inv_prompt_type_to_model_lower:
|
1294 |
+
prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
|
1295 |
+
else:
|
1296 |
+
prompt_type1 = prompt_type_old
|
1297 |
+
|
1298 |
+
# detranscribe
|
1299 |
+
if lora_weights == no_lora_str:
|
1300 |
+
lora_weights = ''
|
1301 |
+
|
1302 |
+
all_kwargs1['lora_weights'] = lora_weights.strip()
|
1303 |
+
model1, tokenizer1, device1 = get_model(**all_kwargs1)
|
1304 |
+
clear_torch_cache()
|
1305 |
+
|
1306 |
+
if kwargs['debug']:
|
1307 |
+
print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1308 |
+
return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
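# model_state layout used throughout: [model, tokenizer, device, base_model_name]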
|
1309 |
+
|
1310 |
+
def dropdown_prompt_type_list(x):
|
1311 |
+
return gr.Dropdown.update(value=x)
|
1312 |
+
|
1313 |
+
def chatbot_list(x, model_used_in):
|
1314 |
+
return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
|
1315 |
+
|
1316 |
+
load_model_args = dict(fn=load_model,
|
1317 |
+
inputs=[model_choice, lora_choice, model_state, prompt_type,
|
1318 |
+
model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
|
1319 |
+
outputs=[model_state, model_used, lora_used, prompt_type])
|
1320 |
+
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
1321 |
+
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
1322 |
+
nochat_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output_nochat)
|
1323 |
+
if not is_public:
|
1324 |
+
load_model_event = load_model_button.click(**load_model_args) \
|
1325 |
+
.then(**prompt_update_args) \
|
1326 |
+
.then(**chatbot_update_args) \
|
1327 |
+
.then(**nochat_update_args) \
|
1328 |
+
.then(clear_torch_cache)
|
1329 |
+
|
1330 |
+
load_model_args2 = dict(fn=load_model,
|
1331 |
+
inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
|
1332 |
+
model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
|
1333 |
+
outputs=[model_state2, model_used2, lora_used2, prompt_type2])
|
1334 |
+
prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
|
1335 |
+
chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
|
1336 |
+
if not is_public:
|
1337 |
+
load_model_event2 = load_model_button2.click(**load_model_args2) \
|
1338 |
+
.then(**prompt_update_args2) \
|
1339 |
+
.then(**chatbot_update_args2) \
|
1340 |
+
.then(clear_torch_cache)
|
1341 |
+
|
1342 |
+
def dropdown_model_list(list0, x):
|
1343 |
+
new_state = [list0[0] + [x]]
|
1344 |
+
new_options = [*new_state[0]]
|
1345 |
+
return gr.Dropdown.update(value=x, choices=new_options), \
|
1346 |
+
gr.Dropdown.update(value=x, choices=new_options), \
|
1347 |
+
'', new_state
|
1348 |
+
|
1349 |
+
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
1350 |
+
inputs=[model_options_state, new_model],
|
1351 |
+
outputs=[model_choice, model_choice2, new_model, model_options_state])
|
1352 |
+
|
1353 |
+
def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
|
1354 |
+
new_state = [list0[0] + [x]]
|
1355 |
+
new_options = [*new_state[0]]
|
1356 |
+
# don't switch drop-down to added lora if already have model loaded
|
1357 |
+
x1 = x if model_used1 == no_model_str else lora_used1
|
1358 |
+
x2 = x if model_used2 == no_model_str else lora_used2
|
1359 |
+
return gr.Dropdown.update(value=x1, choices=new_options), \
|
1360 |
+
gr.Dropdown.update(value=x2, choices=new_options), \
|
1361 |
+
'', new_state
|
1362 |
+
|
1363 |
+
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
1364 |
+
inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, lora_used2],
|
1365 |
+
outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
|
1366 |
+
|
1367 |
+
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
|
1368 |
+
.then(lambda: gr.update(visible=True), None, normal_block) \
|
1369 |
+
.then(**load_model_args).then(**prompt_update_args)
|
1370 |
+
|
1371 |
+
def compare_textbox_fun(x):
|
1372 |
+
return gr.Textbox.update(visible=x)
|
1373 |
+
|
1374 |
+
def compare_column_fun(x):
|
1375 |
+
return gr.Column.update(visible=x)
|
1376 |
+
|
1377 |
+
def compare_prompt_fun(x):
|
1378 |
+
return gr.Dropdown.update(visible=x)
|
1379 |
+
|
1380 |
+
compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, api_name="compare_checkbox") \
|
1381 |
+
.then(compare_column_fun, compare_checkbox, col_model2) \
|
1382 |
+
.then(compare_prompt_fun, compare_checkbox, prompt_type2) \
|
1383 |
+
.then(compare_textbox_fun, compare_checkbox, score_text2)
|
1384 |
+
# FIXME: add score_res2 in condition, but do better
|
1385 |
+
|
1386 |
+
# callback for logging flagged input/output
|
1387 |
+
callback.setup(inputs_list + [text_output], "flagged_data_points")
|
1388 |
+
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1389 |
+
api_name='flag')
|
1390 |
+
flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1391 |
+
api_name='flag_nochat')
|
1392 |
+
|
1393 |
+
def get_system_info():
|
1394 |
+
return gr.Textbox.update(value=system_info_print())
|
1395 |
+
|
1396 |
+
system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')
|
1397 |
+
|
1398 |
+
# don't pass text_output, don't want to clear output, just stop it
|
1399 |
+
# FIXME: have to click once to stop output and second time to stop GPUs going
|
1400 |
+
stop_btn.click(lambda: None, None, None,
|
1401 |
+
cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
|
1402 |
+
queue=False, api_name='stop').then(clear_torch_cache)
|
1403 |
+
|
1404 |
+
demo.queue(concurrency_count=1)
|
1405 |
+
favicon_path = "h2o-logo.svg"
|
1406 |
+
demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
|
1407 |
+
favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
|
1408 |
+
print("Started GUI", flush=True)
|
1409 |
+
demo.block_thread()
|
1410 |
+
|
1411 |
+
|
1412 |
+
input_args_list = ['model_state']
|
1413 |
+
inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
|
1414 |
+
|
1415 |
+
|
1416 |
+
def get_inputs_list(inputs_dict, model_lower):
|
1417 |
+
"""
|
1418 |
+
map gradio objects in locals() to inputs for evaluate().
|
1419 |
+
:param inputs_dict:
|
1420 |
+
:param model_lower:
|
1421 |
+
:return:
|
1422 |
+
"""
|
1423 |
+
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
1424 |
+
inputs_list = []
|
1425 |
+
for k in inputs_list_names:
|
1426 |
+
if k == 'kwargs':
|
1427 |
+
continue
|
1428 |
+
if k in input_args_list + inputs_kwargs_list:
|
1429 |
+
# these are added via partial, not taken as input
|
1430 |
+
continue
|
1431 |
+
if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
|
1432 |
+
continue
|
1433 |
+
inputs_list.append(inputs_dict[k])
|
1434 |
+
return inputs_list
|
1435 |
+
|
1436 |
+
|
1437 |
+
eval_func_param_names = ['instruction',
|
1438 |
+
'iinput',
|
1439 |
+
'context',
|
1440 |
+
'stream_output',
|
1441 |
+
'prompt_type',
|
1442 |
+
'temperature',
|
1443 |
+
'top_p',
|
1444 |
+
'top_k',
|
1445 |
+
'num_beams',
|
1446 |
+
'max_new_tokens',
|
1447 |
+
'min_new_tokens',
|
1448 |
+
'early_stopping',
|
1449 |
+
'max_time',
|
1450 |
+
'repetition_penalty',
|
1451 |
+
'num_return_sequences',
|
1452 |
+
'do_sample',
|
1453 |
+
'chat',
|
1454 |
+
'instruction_nochat',
|
1455 |
+
'iinput_nochat',
|
1456 |
+
]
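# NOTE: this order must match both the evaluate() signature below and the per-example
# parameter lists built in get_generate_params().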
|
1457 |
+
|
1458 |
+
|
1459 |
+
def evaluate(
|
1460 |
+
model_state,
|
1461 |
+
# START NOTE: Examples must have same order of parameters
|
1462 |
+
instruction,
|
1463 |
+
iinput,
|
1464 |
+
context,
|
1465 |
+
stream_output,
|
1466 |
+
prompt_type,
|
1467 |
+
temperature,
|
1468 |
+
top_p,
|
1469 |
+
top_k,
|
1470 |
+
num_beams,
|
1471 |
+
max_new_tokens,
|
1472 |
+
min_new_tokens,
|
1473 |
+
early_stopping,
|
1474 |
+
max_time,
|
1475 |
+
repetition_penalty,
|
1476 |
+
num_return_sequences,
|
1477 |
+
do_sample,
|
1478 |
+
chat,
|
1479 |
+
instruction_nochat,
|
1480 |
+
iinput_nochat,
|
1481 |
+
# END NOTE: Examples must have same order of parameters
|
1482 |
+
src_lang=None,
|
1483 |
+
tgt_lang=None,
|
1484 |
+
debug=False,
|
1485 |
+
save_dir=None,
|
1486 |
+
hard_stop_list=None,
|
1487 |
+
sanitize_bot_response=True,
|
1488 |
+
model_state0=None,
|
1489 |
+
**kwargs,
|
1490 |
+
):
|
1491 |
+
if debug:
|
1492 |
+
locals_dict = locals().copy()
|
1493 |
+
locals_dict.pop('model_state', None)
|
1494 |
+
locals_dict.pop('model_state0', None)
|
1495 |
+
print(locals_dict)
|
1496 |
+
|
1497 |
+
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
1498 |
+
|
1499 |
+
if model_state0 is None:
|
1500 |
+
# e.g. for no gradio case, set dummy value, else should be set
|
1501 |
+
model_state0 = [None, None, None, None]
|
1502 |
+
|
1503 |
+
if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
|
1504 |
+
# try to free-up original model (i.e. list was passed as reference)
|
1505 |
+
if model_state0 is not None and model_state0[0] is not None:
|
1506 |
+
model_state0[0].cpu()
|
1507 |
+
model_state0[0] = None
|
1508 |
+
# try to free-up original tokenizer (i.e. list was passed as reference)
|
1509 |
+
if model_state0 is not None and model_state0[1] is not None:
|
1510 |
+
model_state0[1] = None
|
1511 |
+
clear_torch_cache()
|
1512 |
+
model, tokenizer, device, base_model = model_state
|
1513 |
+
elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
|
1514 |
+
assert isinstance(model_state[0], str)
|
1515 |
+
model, tokenizer, device, base_model = model_state0
|
1516 |
+
else:
|
1517 |
+
raise AssertionError(no_model_msg)
|
1518 |
+
|
1519 |
+
if base_model is None:
|
1520 |
+
raise AssertionError(no_model_msg)
|
1521 |
+
|
1522 |
+
assert base_model.strip(), no_model_msg
|
1523 |
+
assert model, "Model is missing"
|
1524 |
+
assert tokenizer, "Tokenizer is missing"
|
1525 |
+
|
1526 |
+
# choose chat or non-chat mode
|
1527 |
+
if not chat:
|
1528 |
+
instruction = instruction_nochat
|
1529 |
+
iinput = iinput_nochat
|
1530 |
+
|
1531 |
+
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1532 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
1533 |
+
prompt = prompter.generate_prompt(data_point)
|
1534 |
+
|
1535 |
+
if hard_stop_list is None:
|
1536 |
+
# acts like undo on user entry and bot response
|
1537 |
+
hard_stop_list = []
|
1538 |
+
|
1539 |
+
if isinstance(tokenizer, str):
|
1540 |
+
# pipeline
|
1541 |
+
if tokenizer == "summarization":
|
1542 |
+
key = 'summary_text'
|
1543 |
+
else:
|
1544 |
+
raise RuntimeError("No such task type %s" % tokenizer)
|
1545 |
+
# NOTE: uses max_length only
|
1546 |
+
yield model(prompt, max_length=max_new_tokens)[0][key]
|
1547 |
+
|
1548 |
+
if 'mbart-' in base_model.lower():
|
1549 |
+
assert src_lang is not None
|
1550 |
+
tokenizer.src_lang = languages_covered()[src_lang]
|
1551 |
+
|
1552 |
+
if chat:
|
1553 |
+
# override, ignore user change
|
1554 |
+
num_return_sequences = 1
|
1555 |
+
if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
|
1556 |
+
if prompt_type == 'human_bot':
|
1557 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
1558 |
+
# stopping only starts once output is beyond prompt
|
1559 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
1560 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
1561 |
+
encounters = [1, 2]
|
1562 |
+
elif prompt_type == 'instruct_vicuna':
|
1563 |
+
# even below is not enough, generic strings and many ways to encode
|
1564 |
+
stop_words = [
|
1565 |
+
'### Human:',
|
1566 |
+
"""
|
1567 |
+
### Human:""",
|
1568 |
+
"""
|
1569 |
+
### Human:
|
1570 |
+
""",
|
1571 |
+
'### Assistant:',
|
1572 |
+
"""
|
1573 |
+
### Assistant:""",
|
1574 |
+
"""
|
1575 |
+
### Assistant:
|
1576 |
+
""",
|
1577 |
+
]
|
1578 |
+
encounters = [1, 2]
|
1579 |
+
else:
|
1580 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
1581 |
+
stop_words = ['### End']
|
1582 |
+
encounters = [1]
|
1583 |
+
stop_words_ids = [
|
1584 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
1585 |
+
# handle single token case
|
1586 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
1587 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
1588 |
+
# avoid padding in front of tokens
|
1589 |
+
if tokenizer.pad_token:
|
1590 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
1591 |
+
# handle fake \n added
|
1592 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
1593 |
+
# build stopper
|
1594 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
|
1595 |
+
else:
|
1596 |
+
stopping_criteria = StoppingCriteriaList()
|
1597 |
+
|
1598 |
+
# help to avoid errors like:
|
1599 |
+
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
1600 |
+
# RuntimeError: expected scalar type Half but found Float
|
1601 |
+
# with - 256
|
1602 |
+
max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
|
1603 |
+
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1604 |
+
output_smallest = 30 * 4
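# the *4 factors are rough characters-per-token estimates for string-level truncation before
# tokenizing; output_smallest reserves room for a minimal (~30 token) response.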
|
1605 |
+
prompt = prompt[-cutoff_len - output_smallest:]
|
1606 |
+
inputs = tokenizer(prompt,
|
1607 |
+
return_tensors="pt",
|
1608 |
+
truncation=True,
|
1609 |
+
max_length=max_length_tokenize)
|
1610 |
+
if debug and len(inputs["input_ids"]) > 0:
|
1611 |
+
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1612 |
+
input_ids = inputs["input_ids"].to(device)
|
1613 |
+
generation_config = GenerationConfig(
|
1614 |
+
temperature=float(temperature),
|
1615 |
+
top_p=float(top_p),
|
1616 |
+
top_k=top_k,
|
1617 |
+
num_beams=num_beams,
|
1618 |
+
do_sample=do_sample,
|
1619 |
+
repetition_penalty=float(repetition_penalty),
|
1620 |
+
num_return_sequences=num_return_sequences,
|
1621 |
+
renormalize_logits=True,
|
1622 |
+
remove_invalid_values=True,
|
1623 |
+
**kwargs,
|
1624 |
+
)
|
1625 |
+
|
1626 |
+
gen_kwargs = dict(input_ids=input_ids,
|
1627 |
+
generation_config=generation_config,
|
1628 |
+
return_dict_in_generate=True,
|
1629 |
+
output_scores=True,
|
1630 |
+
max_new_tokens=max_new_tokens, # prompt + new
|
1631 |
+
min_new_tokens=min_new_tokens, # prompt + new
|
1632 |
+
early_stopping=early_stopping, # False, True, "never"
|
1633 |
+
max_time=max_time,
|
1634 |
+
stopping_criteria=stopping_criteria,
|
1635 |
+
)
|
1636 |
+
if 'gpt2' in base_model.lower():
|
1637 |
+
gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
|
1638 |
+
elif 'mbart-' in base_model.lower():
|
1639 |
+
assert tgt_lang is not None
|
1640 |
+
tgt_lang = languages_covered()[tgt_lang]
|
1641 |
+
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
|
1642 |
+
else:
|
1643 |
+
gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
|
1644 |
+
|
1645 |
+
decoder = functools.partial(tokenizer.decode,
|
1646 |
+
skip_special_tokens=True,
|
1647 |
+
clean_up_tokenization_spaces=True,
|
1648 |
+
)
|
1649 |
+
decoder_raw = functools.partial(tokenizer.decode,
|
1650 |
+
skip_special_tokens=False,
|
1651 |
+
clean_up_tokenization_spaces=True,
|
1652 |
+
)
|
1653 |
+
|
1654 |
+
with torch.no_grad():
|
1655 |
+
# decoded tokenized prompt can deviate from prompt due to special characters
|
1656 |
+
inputs_decoded = decoder(input_ids[0])
|
1657 |
+
inputs_decoded_raw = decoder_raw(input_ids[0])
|
1658 |
+
if inputs_decoded == prompt:
|
1659 |
+
# normal
|
1660 |
+
pass
|
1661 |
+
elif inputs_decoded.lstrip() == prompt.lstrip():
|
1662 |
+
# sometimes extra space in front, make prompt same for prompt removal
|
1663 |
+
prompt = inputs_decoded
|
1664 |
+
elif inputs_decoded_raw == prompt:
|
1665 |
+
# some models specify special tokens that are part of normal prompt, so can't skip them
|
1666 |
+
inputs_decoded_raw = inputs_decoded
|
1667 |
+
decoder = decoder_raw
|
1668 |
+
else:
|
1669 |
+
print("WARNING: Special characters in prompt", flush=True)
|
1670 |
+
if stream_output:
|
1671 |
+
def generate(callback=None, **kwargs):
|
1672 |
+
# re-order stopping criteria so Stream runs first and emits all chunks before stopping for other reasons
|
1673 |
+
stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
|
1674 |
+
kwargs['stopping_criteria'] = StoppingCriteriaList()
|
1675 |
+
kwargs['stopping_criteria'].append(Stream(func=callback))
|
1676 |
+
for stopping_criteria1 in stopping_criteria0:
|
1677 |
+
kwargs['stopping_criteria'].append(stopping_criteria1)
|
1678 |
+
|
1679 |
+
try:
|
1680 |
+
model.generate(**kwargs)
|
1681 |
+
except torch.cuda.OutOfMemoryError as e:
|
1682 |
+
print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1683 |
+
flush=True)
|
1684 |
+
if kwargs['input_ids'] is not None:
|
1685 |
+
kwargs['input_ids'].cpu()
|
1686 |
+
kwargs['input_ids'] = None
|
1687 |
+
traceback.print_exc()
|
1688 |
+
clear_torch_cache()
|
1689 |
+
return
|
1690 |
+
except (Exception, RuntimeError) as e:
|
1691 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
1692 |
+
'expected scalar type Half but found Float' in str(e) or \
|
1693 |
+
'probability tensor contains either' in str(e) or \
|
1694 |
+
'cublasLt ran into an error!' in str(e):
|
1695 |
+
print(
|
1696 |
+
"GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1697 |
+
flush=True)
|
1698 |
+
traceback.print_exc()
|
1699 |
+
clear_torch_cache()
|
1700 |
+
if raise_generate_gpu_exceptions:
|
1701 |
+
raise
|
1702 |
+
return
|
1703 |
+
else:
|
1704 |
+
raise
|
1705 |
+
|
1706 |
+
decoded_output = None
|
1707 |
+
for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
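# each streamed output is the token-id sequence generated so far; decode it and let the
# prompter strip the prompt portion before yielding the partial response.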
|
1708 |
+
decoded_output = decoder(output)
|
1709 |
+
if output[-1] in [tokenizer.eos_token_id]:
|
1710 |
+
if debug:
|
1711 |
+
print("HIT EOS", flush=True)
|
1712 |
+
break
|
1713 |
+
if any(ele in decoded_output for ele in hard_stop_list):
|
1714 |
+
raise StopIteration
|
1715 |
+
yield prompter.get_response(decoded_output, prompt=inputs_decoded,
|
1716 |
+
sanitize_bot_response=sanitize_bot_response)
|
1717 |
+
if save_dir and decoded_output:
|
1718 |
+
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
1719 |
+
else:
|
1720 |
+
outputs = model.generate(**gen_kwargs)
|
1721 |
+
outputs = [decoder(s) for s in outputs.sequences]
|
1722 |
+
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
1723 |
+
sanitize_bot_response=sanitize_bot_response)
|
1724 |
+
if save_dir and outputs and len(outputs) >= 1:
|
1725 |
+
decoded_output = prompt + outputs[0]
|
1726 |
+
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
1727 |
+
|
1728 |
+
|
1729 |
+
def get_generate_params(model_lower, chat,
|
1730 |
+
stream_output, show_examples,
|
1731 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
1732 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
1733 |
+
repetition_penalty, num_return_sequences,
|
1734 |
+
do_sample):
|
1735 |
+
use_defaults = False
|
1736 |
+
use_default_examples = True
|
1737 |
+
examples = []
|
1738 |
+
task_info = f"{prompt_type}"
|
1739 |
+
if model_lower:
|
1740 |
+
print(f"Using Model {model_lower}", flush=True)
|
1741 |
+
else:
|
1742 |
+
print("No model defined yet", flush=True)
|
1743 |
+
|
1744 |
+
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
|
1745 |
+
early_stopping = early_stopping if early_stopping is not None else False
|
1746 |
+
max_time_defaults = 60 * 3
|
1747 |
+
max_time = max_time if max_time is not None else max_time_defaults
|
1748 |
+
|
1749 |
+
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
1750 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
1751 |
+
|
1752 |
+
# examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
|
1753 |
+
if show_examples is None:
|
1754 |
+
if chat:
|
1755 |
+
show_examples = False
|
1756 |
+
else:
|
1757 |
+
show_examples = True
|
1758 |
+
|
1759 |
+
summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
|
1760 |
+
Philipp: Sure you can use the new Hugging Face Deep Learning Container.
|
1761 |
+
Jeff: ok.
|
1762 |
+
Jeff: and how can I get started?
|
1763 |
+
Jeff: where can I find documentation?
|
1764 |
+
Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
|
1765 |
+
|
1766 |
+
if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
|
1767 |
+
placeholder_instruction = summarize_example1
|
1768 |
+
placeholder_input = ""
|
1769 |
+
use_defaults = True
|
1770 |
+
use_default_examples = False
|
1771 |
+
examples += [
|
1772 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1773 |
+
1.0, 1,
|
1774 |
+
False]]
|
1775 |
+
task_info = "Summarization"
|
1776 |
+
elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
|
1777 |
+
placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
|
1778 |
+
placeholder_input = ""
|
1779 |
+
use_defaults = True
|
1780 |
+
use_default_examples = True
|
1781 |
+
task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
|
1782 |
+
elif 'mbart-' in model_lower:
|
1783 |
+
placeholder_instruction = "The girl has long hair."
|
1784 |
+
placeholder_input = ""
|
1785 |
+
use_defaults = True
|
1786 |
+
use_default_examples = False
|
1787 |
+
examples += [
|
1788 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1789 |
+
1.0, 1,
|
1790 |
+
False]]
|
1791 |
+
elif 'gpt2' in model_lower:
|
1792 |
+
placeholder_instruction = "The sky is"
|
1793 |
+
placeholder_input = ""
|
1794 |
+
prompt_type = prompt_type or 'plain'
|
1795 |
+
use_default_examples = True # some will be odd "continuations" but can be ok
|
1796 |
+
examples += [
|
1797 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1798 |
+
1.0, 1,
|
1799 |
+
False]]
|
1800 |
+
task_info = "Auto-complete phrase, code, etc."
|
1801 |
+
use_defaults = True
|
1802 |
+
else:
|
1803 |
+
if chat:
|
1804 |
+
placeholder_instruction = "Enter a question or imperative."
|
1805 |
+
else:
|
1806 |
+
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
|
1807 |
+
placeholder_input = ""
|
1808 |
+
if model_lower:
|
1809 |
+
prompt_type = prompt_type or 'human_bot'
|
1810 |
+
else:
|
1811 |
+
prompt_type = ''
|
1812 |
+
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
1813 |
+
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
|
1814 |
+
False]]
|
1815 |
+
task_info = "No task"
|
1816 |
+
if prompt_type == 'instruct':
|
1817 |
+
task_info = "Answer question or follow imperative as instruction with optionally input."
|
1818 |
+
elif prompt_type == 'plain':
|
1819 |
+
task_info = "Auto-complete phrase, code, etc."
|
1820 |
+
elif prompt_type == 'human_bot':
|
1821 |
+
if chat:
|
1822 |
+
task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
|
1823 |
+
else:
|
1824 |
+
task_info = "Ask question/imperative (input concatenated with instruction)"
|
1825 |
+
|
1826 |
+
# revert to plain if still nothing
|
1827 |
+
prompt_type = prompt_type or 'plain'
|
1828 |
+
if use_defaults:
|
1829 |
+
temperature = 1.0 if temperature is None else temperature
|
1830 |
+
top_p = 1.0 if top_p is None else top_p
|
1831 |
+
top_k = 40 if top_k is None else top_k
|
1832 |
+
num_beams = num_beams or 1
|
1833 |
+
max_new_tokens = max_new_tokens or 128
|
1834 |
+
repetition_penalty = repetition_penalty or 1.07
|
1835 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1836 |
+
do_sample = False if do_sample is None else do_sample
|
1837 |
+
else:
|
1838 |
+
temperature = 0.1 if temperature is None else temperature
|
1839 |
+
top_p = 0.75 if top_p is None else top_p
|
1840 |
+
top_k = 40 if top_k is None else top_k
|
1841 |
+
if chat:
|
1842 |
+
num_beams = num_beams or 1
|
1843 |
+
else:
|
1844 |
+
num_beams = num_beams or 4
|
1845 |
+
max_new_tokens = max_new_tokens or 256
|
1846 |
+
repetition_penalty = repetition_penalty or 1.07
|
1847 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1848 |
+
do_sample = False if do_sample is None else do_sample
|
1849 |
+
# doesn't include chat, instruction_nochat, iinput_nochat, added later
|
1850 |
+
params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
|
1851 |
+
early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
|
1852 |
+
|
1853 |
+
if use_default_examples:
|
1854 |
+
examples += [
|
1855 |
+
["Translate English to French", "Good morning"] + params_list,
|
1856 |
+
["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
|
1857 |
+
["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
|
1858 |
+
[
|
1859 |
+
"Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
|
1860 |
+
''] + params_list,
|
1861 |
+
['Translate to German: My name is Arthur', ''] + params_list,
|
1862 |
+
["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
|
1863 |
+
['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
|
1864 |
+
''] + params_list,
|
1865 |
+
['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
|
1866 |
+
['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
|
1867 |
+
["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
|
1868 |
+
[
|
1869 |
+
"Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
|
1870 |
+
''] + params_list,
|
1871 |
+
['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
|
1872 |
+
[
|
1873 |
+
'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
|
1874 |
+
''] + params_list,
|
1875 |
+
["""def area_of_rectangle(a: float, b: float):
|
1876 |
+
\"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
|
1877 |
+
["""# a function in native python:
|
1878 |
+
def mean(a):
|
1879 |
+
return sum(a)/len(a)
|
1880 |
+
|
1881 |
+
# the same function using numpy:
|
1882 |
+
import numpy as np
|
1883 |
+
def mean(a):""", ''] + params_list,
|
1884 |
+
["""X = np.random.randn(100, 100)
|
1885 |
+
y = np.random.randint(0, 1, 100)
|
1886 |
+
|
1887 |
+
# fit random forest classifier with 20 estimators""", ''] + params_list,
|
1888 |
+
]
|
1889 |
+
|
1890 |
+
src_lang = "English"
|
1891 |
+
tgt_lang = "Russian"
|
1892 |
+
|
1893 |
+
# move to correct position
|
1894 |
+
for example in examples:
|
1895 |
+
example += [chat, '', '']
|
1896 |
+
# adjust examples if non-chat mode
|
1897 |
+
if not chat:
|
1898 |
+
example[eval_func_param_names.index('instruction_nochat')] = example[
|
1899 |
+
eval_func_param_names.index('instruction')]
|
1900 |
+
example[eval_func_param_names.index('instruction')] = ''
|
1901 |
+
|
1902 |
+
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
|
1903 |
+
example[eval_func_param_names.index('iinput')] = ''
|
1904 |
+
|
1905 |
+
return placeholder_instruction, placeholder_input, \
|
1906 |
+
stream_output, show_examples, \
|
1907 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
1908 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
1909 |
+
repetition_penalty, num_return_sequences, \
|
1910 |
+
do_sample, \
|
1911 |
+
src_lang, tgt_lang, \
|
1912 |
+
examples, \
|
1913 |
+
task_info
|
1914 |
+
|
1915 |
+
|
1916 |
+
def languages_covered():
|
1917 |
+
# https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
|
1918 |
+
covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
|
1919 |
+
covered = covered.split(', ')
|
1920 |
+
covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
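# e.g. "English (en_XX)" becomes the entry {'English': 'en_XX'}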
|
1921 |
+
return covered
|
1922 |
+
|
1923 |
+
|
1924 |
+
def test_test_prompt(prompt_type='instruct', data_point=0):
|
1925 |
+
example_data_point = example_data_points[data_point]
|
1926 |
+
example_data_point.pop('output', None)
|
1927 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
1928 |
+
|
1929 |
+
|
1930 |
+
if __name__ == "__main__":
|
1931 |
+
print("""
|
1932 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
|
1933 |
+
python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
|
1934 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
|
1935 |
+
|
1936 |
+
# generate without lora weights, no prompt
|
1937 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
|
1938 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
|
1939 |
+
|
1940 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
|
1941 |
+
# OpenChatKit settings:
|
1942 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
|
1943 |
+
|
1944 |
+
python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
|
1945 |
+
python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
|
1946 |
+
python generate.py --base_model='philschmid/bart-large-cnn-samsum'
|
1947 |
+
python generate.py --base_model='philschmid/flan-t5-base-samsum'
|
1948 |
+
python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
|
1949 |
+
|
1950 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
1951 |
+
|
1952 |
+
must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
|
1953 |
+
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
1954 |
+
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
|
1955 |
+
|
1956 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
|
1957 |
+
|
1958 |
+
""", flush=True)
|
1959 |
+
fire.Fire(main)
|
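# Illustrative sketch (not part of the original app.py): resolving the human-readable
# language names returned by languages_covered() into the mbart-50 codes expected for
# src_lang/tgt_lang above. The helper name and its defaults are hypothetical.
def resolve_lang_codes(source_name='English', target_name='German'):
    covered = languages_covered()  # e.g. {'English': 'en_XX', 'German': 'de_DE', ...}
    return covered[source_name], covered[target_name]

# src_lang, tgt_lang = resolve_lang_codes('English', 'French')  # -> ('en_XX', 'fr_XX')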
client_test.py
ADDED
@@ -0,0 +1,93 @@
"""
Client test.

Run server:

python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b

NOTE: For private models, add --use_auth_token=True

NOTE: --infer_devices=True (default) must be used for multi-GPU if you see failures with cuda:x cuda:y mismatches.
Currently, this will force the model onto a single GPU.

Then run this client as:

python client_test.py
"""

debug = False

import os
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
from gradio_client import Client

client = Client("http://localhost:7860")
if debug:
    print(client.view_api(all_endpoints=True))

instruction = ''  # only for chat=True
iinput = ''  # only for chat=True
context = ''
# streaming output is supported; it loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output = False
prompt_type = 'human_bot'
temperature = 0.1
top_p = 0.75
top_k = 40
num_beams = 1
max_new_tokens = 50
min_new_tokens = 0
early_stopping = False
max_time = 20
repetition_penalty = 1.0
num_return_sequences = 1
do_sample = True
# only these 2 below used if pass chat=False
chat = False
instruction_nochat = "Who are you?"
iinput_nochat = ''


def test_client_basic():
    args = [instruction,
            iinput,
            context,
            stream_output,
            prompt_type,
            temperature,
            top_p,
            top_k,
            num_beams,
            max_new_tokens,
            min_new_tokens,
            early_stopping,
            max_time,
            repetition_penalty,
            num_return_sequences,
            do_sample,
            chat,
            instruction_nochat,
            iinput_nochat,
            ]
    api_name = '/submit_nochat'
    res = client.predict(
        *tuple(args),
        api_name=api_name,
    )
    res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
    print(res_dict)


import markdown  # pip install markdown
from bs4 import BeautifulSoup  # pip install beautifulsoup4


def md_to_text(md):
    html = markdown.markdown(md)
    soup = BeautifulSoup(html, features='html.parser')
    return soup.get_text()


if __name__ == '__main__':
    test_client_basic()
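# Hypothetical helper (not part of client_test.py): reuse the argument order from
# test_client_basic() to send an arbitrary question to the /submit_nochat endpoint.
# The helper name ask_nochat is illustrative only.
def ask_nochat(question):
    my_args = [instruction, iinput, context, stream_output, prompt_type,
               temperature, top_p, top_k, num_beams, max_new_tokens,
               min_new_tokens, early_stopping, max_time, repetition_penalty,
               num_return_sequences, do_sample, chat, question, iinput_nochat]
    res = client.predict(*my_args, api_name='/submit_nochat')
    return md_to_text(res)

# print(ask_nochat("What is Driverless AI?"))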
finetune.py
ADDED
@@ -0,0 +1,934 @@
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from datetime import datetime
|
9 |
+
from typing import List, Union
|
10 |
+
import fire
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset, concatenate_datasets
|
14 |
+
import transformers
|
15 |
+
import torch.distributed as dist
|
16 |
+
|
17 |
+
from peft import (
|
18 |
+
prepare_model_for_int8_training,
|
19 |
+
LoraConfig,
|
20 |
+
get_peft_model,
|
21 |
+
get_peft_model_state_dict,
|
22 |
+
set_peft_model_state_dict,
|
23 |
+
)
|
24 |
+
|
25 |
+
from peft import mapping
|
26 |
+
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
27 |
+
|
28 |
+
|
29 |
+
def log(*args, **kwargs):
|
30 |
+
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
31 |
+
print(*args, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
try:
|
35 |
+
import neptune
|
36 |
+
from transformers.integrations import NeptuneCallback
|
37 |
+
|
38 |
+
neptune_run = neptune.init_run(
|
39 |
+
source_files=[],
|
40 |
+
)
|
41 |
+
log("Connected to Neptune.")
|
42 |
+
except ImportError:
|
43 |
+
neptune_run = None
|
44 |
+
log("Please pip install neptune for tracking.")
|
45 |
+
except neptune.exceptions.NeptuneMissingApiTokenException:
|
46 |
+
neptune_run = None
|
47 |
+
os.environ["NEPTUNE_MODE"] = 'debug'
|
48 |
+
log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
|
49 |
+
|
50 |
+
from enum import Enum
|
51 |
+
|
52 |
+
|
53 |
+
class PromptType(Enum):
|
54 |
+
plain = 0
|
55 |
+
instruct = 1
|
56 |
+
quality = 2
|
57 |
+
human_bot = 3
|
58 |
+
dai_faq = 4
|
59 |
+
summarize = 5
|
60 |
+
simple_instruct = 6
|
61 |
+
instruct_vicuna = 7
|
62 |
+
instruct_with_end = 8
|
63 |
+
human_bot_orig = 9
|
64 |
+
|
65 |
+
|
66 |
+
prompt_type_to_model_name = {
|
67 |
+
'plain': [
|
68 |
+
'EleutherAI/gpt-j-6B',
|
69 |
+
'EleutherAI/pythia-6.9b',
|
70 |
+
'EleutherAI/pythia-12b',
|
71 |
+
'EleutherAI/pythia-12b-deduped',
|
72 |
+
'EleutherAI/gpt-neox-20b',
|
73 |
+
'decapoda-research/llama-7b-hf',
|
74 |
+
'decapoda-research/llama-13b-hf',
|
75 |
+
'decapoda-research/llama-30b-hf',
|
76 |
+
'decapoda-research/llama-65b-hf',
|
77 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
78 |
+
'philschmid/bart-large-cnn-samsum',
|
79 |
+
'philschmid/flan-t5-base-samsum',
|
80 |
+
'gpt2',
|
81 |
+
'distilgpt2',
|
82 |
+
],
|
83 |
+
'instruct': [],
|
84 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
85 |
+
'quality': [],
|
86 |
+
'human_bot': [
|
87 |
+
'h2oai/h2ogpt-oig-oasst1-256-12b',
|
88 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
89 |
+
'h2oai/h2ogpt-oasst1-256-20b',
|
90 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
91 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b',
|
92 |
+
],
|
93 |
+
'dai_faq': [],
|
94 |
+
'summarize': [],
|
95 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
96 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
|
97 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
98 |
+
}
|
99 |
+
|
100 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
101 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
102 |
+
|
103 |
+
human = '<human>:'
|
104 |
+
bot = "<bot>:"
|
105 |
+
|
106 |
+
prompt_types_strings = []
|
107 |
+
for p in PromptType:
|
108 |
+
prompt_types_strings.extend([p.name])
|
109 |
+
|
110 |
+
|
111 |
+
prompt_types = []
|
112 |
+
for p in PromptType:
|
113 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
114 |
+
|
115 |
+
|
116 |
+
# supported by huggingface evaluate
|
117 |
+
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
|
118 |
+
|
119 |
+
|
120 |
+
def train(
|
121 |
+
save_code: bool = False,
|
122 |
+
run_id: int = None,
|
123 |
+
|
124 |
+
base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
|
125 |
+
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
126 |
+
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
127 |
+
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
128 |
+
# base_model: str = 'EleutherAI/pythia-12b-deduped',
|
129 |
+
# base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
|
130 |
+
# base_model: str = 'decapoda-research/llama-7b-hf',
|
131 |
+
# base_model: str = 'decapoda-research/llama-13b-hf',
|
132 |
+
# base_model: str = 'decapoda-research/llama-30b-hf',
|
133 |
+
# base_model: str = 'EleutherAI/gpt-j-6B',
|
134 |
+
|
135 |
+
# only needed if base_model is self-exported HF state without tokenizer
|
136 |
+
tokenizer_base_model: str = None,
|
137 |
+
# tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
|
138 |
+
|
139 |
+
data_path: str = None,
|
140 |
+
data_col_dict: dict = None,
|
141 |
+
# data_path: str = "./dai_docs.train.json",
|
142 |
+
prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
|
143 |
+
|
144 |
+
valid_path: str = None,
|
145 |
+
# valid_path: str = "./dai_docs.valid.json",
|
146 |
+
|
147 |
+
# data_mix_in_path: str = "laion/OIG", # way too big, medium quality
|
148 |
+
data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
|
149 |
+
data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
|
150 |
+
data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
|
151 |
+
data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
|
152 |
+
|
153 |
+
output_dir: str = None,
|
154 |
+
|
155 |
+
# LoRA checkpoint continuation
|
156 |
+
lora_weights: str = "",
|
157 |
+
|
158 |
+
# batching training hyperparams
|
159 |
+
batch_size: int = 128,
|
160 |
+
micro_batch_size: int = 4,
|
161 |
+
gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
|
162 |
+
fp16=True,
|
163 |
+
|
164 |
+
# general training hyperparams
|
165 |
+
num_epochs: float = 1,
|
166 |
+
learning_rate: float = 3e-4,
|
167 |
+
|
168 |
+
# validation settings
|
169 |
+
val_set_size: int = None,
|
170 |
+
val_metrics: List[str] = [],
|
171 |
+
eval_steps: int = None, # to control eval steps via steps
|
172 |
+
eval_epochs: float = None, # to control eval steps via epochs
|
173 |
+
|
174 |
+
# lora hyperparams
|
175 |
+
lora_r: int = 8,
|
176 |
+
lora_alpha: int = 16,
|
177 |
+
lora_dropout: float = 0.05,
|
178 |
+
lora_target_modules: List[str] = None,
|
179 |
+
llama_type: bool = None,
|
180 |
+
|
181 |
+
# llm hyperparams
|
182 |
+
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
183 |
+
group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
|
184 |
+
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
185 |
+
cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
|
186 |
+
|
187 |
+
# torch training params
|
188 |
+
ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
|
189 |
+
local_files_only: bool = False, # else will download new versions, normally unwanted
|
190 |
+
resume_download: bool = True,
|
191 |
+
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
192 |
+
warmup_steps: int = 100,
|
193 |
+
logging_steps: int = 1,
|
194 |
+
save_steps: int = None, # must be round multiple of eval_steps
|
195 |
+
add_eos_token: bool = False,
|
196 |
+
):
|
197 |
+
# allow set token directly
|
198 |
+
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
199 |
+
|
200 |
+
prompt_type = str(prompt_type) # migration from integers
|
201 |
+
assert prompt_type in prompt_types
|
202 |
+
|
203 |
+
world_size = int(os.getenv("WORLD_SIZE", 1))
|
204 |
+
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
205 |
+
rank = int(os.getenv("RANK", 0))
|
206 |
+
print(f"local_rank: {local_rank}")
|
207 |
+
print(f"global rank: {rank}")
|
208 |
+
|
209 |
+
gpus = max(world_size, torch.cuda.device_count())
|
210 |
+
run_id = run_id or 0
|
211 |
+
if not data_path:
|
212 |
+
raise ValueError("No data_path provided")
|
213 |
+
if not output_dir:
|
214 |
+
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
215 |
+
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
216 |
+
raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
|
217 |
+
else:
|
218 |
+
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
219 |
+
raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
220 |
+
device_map = "auto"
|
221 |
+
|
222 |
+
if save_code:
|
223 |
+
copy_code(run_id)
|
224 |
+
if tokenizer_base_model is None:
|
225 |
+
tokenizer_base_model = base_model
|
226 |
+
if llama_type is None:
|
227 |
+
llama_type = "llama" in base_model.lower()
|
228 |
+
assert (
|
229 |
+
base_model
|
230 |
+
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
231 |
+
gradient_accumulation_steps = batch_size // micro_batch_size
|
232 |
+
assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
|
233 |
+
|
234 |
+
device_map = "auto"
|
235 |
+
|
236 |
+
locals_dict = locals()
|
237 |
+
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
238 |
+
log(f"Training model with params:\n{locals_print}")
|
239 |
+
log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
|
240 |
+
|
241 |
+
max_memory = None
|
242 |
+
if gpus > 1:
|
243 |
+
if ddp:
|
244 |
+
log("Distributed: data parallel")
|
245 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
246 |
+
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
247 |
+
else:
|
248 |
+
free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
|
249 |
+
max_memory = f"{free_in_GB - 2}GB"
|
250 |
+
max_memory = {i: max_memory for i in range(gpus)}
|
251 |
+
log("world_size: %d" % world_size)
|
252 |
+
log("num_gpus: %d" % gpus)
|
253 |
+
log("max mem: %s" % max_memory)
|
254 |
+
|
255 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
|
256 |
+
|
257 |
+
model = model_loader.from_pretrained(
|
258 |
+
base_model,
|
259 |
+
load_in_8bit=True,
|
260 |
+
device_map=device_map,
|
261 |
+
torch_dtype=torch.float16,
|
262 |
+
max_memory=max_memory,
|
263 |
+
local_files_only=local_files_only,
|
264 |
+
resume_download=resume_download,
|
265 |
+
use_auth_token=use_auth_token,
|
266 |
+
)
|
267 |
+
if gpus > 1:
|
268 |
+
if not ddp:
|
269 |
+
log("model parallel")
|
270 |
+
model.is_parallelizable = True
|
271 |
+
model.model_parallel = True
|
272 |
+
|
273 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
274 |
+
local_files_only=local_files_only,
|
275 |
+
resume_download=resume_download,
|
276 |
+
use_auth_token=use_auth_token)
|
277 |
+
|
278 |
+
tokenizer.pad_token_id = 0 # different from the eos token
|
279 |
+
# when generating, we will use the logits of right-most token to predict the next token
|
280 |
+
# so the padding should be on the left,
|
281 |
+
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
282 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
283 |
+
|
284 |
+
def tokenize(prompt, add_eos_token=True):
|
285 |
+
# there's probably a way to do this with the tokenizer settings
|
286 |
+
# but again, gotta move fast
|
287 |
+
result = tokenizer(
|
288 |
+
prompt,
|
289 |
+
truncation=True,
|
290 |
+
max_length=cutoff_len,
|
291 |
+
padding=False,
|
292 |
+
return_tensors=None,
|
293 |
+
)
|
294 |
+
if (
|
295 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
296 |
+
and len(result["input_ids"]) < cutoff_len
|
297 |
+
and add_eos_token
|
298 |
+
):
|
299 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
300 |
+
result["attention_mask"].append(1)
|
301 |
+
|
302 |
+
result["labels"] = result["input_ids"].copy()
|
303 |
+
|
304 |
+
return result
|
305 |
+
|
306 |
+
def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
|
307 |
+
full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
|
308 |
+
tokenized_full_prompt = tokenize(full_prompt)
|
309 |
+
if not train_on_inputs:
|
310 |
+
user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
|
311 |
+
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
|
312 |
+
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
313 |
+
if add_eos:
|
314 |
+
user_prompt_len -= 1
|
315 |
+
|
316 |
+
# ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
|
317 |
+
tokenized_full_prompt["labels"] = [
|
318 |
+
-100
|
319 |
+
] * user_prompt_len + tokenized_full_prompt["labels"][
|
320 |
+
user_prompt_len:
|
321 |
+
] # could be sped up, probably
|
322 |
+
return tokenized_full_prompt
|
323 |
+
|
324 |
+
if "gpt-neox" not in base_model or True:
|
325 |
+
model = prepare_model_for_int8_training(model)
|
326 |
+
else:
|
327 |
+
model = prepare_model_for_int8_training(
|
328 |
+
model,
|
329 |
+
output_embedding_layer_name="embed_out", # keep output logits in float32
|
330 |
+
layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
|
331 |
+
)
|
332 |
+
if lora_weights:
|
333 |
+
from peft import PeftModel
|
334 |
+
model = PeftModel.from_pretrained(
|
335 |
+
model,
|
336 |
+
lora_weights,
|
337 |
+
torch_dtype=torch.float16,
|
338 |
+
device_map=device_map,
|
339 |
+
local_files_only=local_files_only,
|
340 |
+
resume_download=resume_download,
|
341 |
+
use_auth_token=use_auth_token,
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
if lora_target_modules is None:
|
345 |
+
base_model_lower = base_model.lower()
|
346 |
+
if base_model_lower in lora_mappings:
|
347 |
+
lora_target_modules_cand = [lora_mappings[base_model_lower]]
|
348 |
+
else:
|
349 |
+
lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
|
350 |
+
else:
|
351 |
+
lora_target_modules_cand = [lora_target_modules]
|
352 |
+
|
353 |
+
for lora_target_modules in lora_target_modules_cand:
|
354 |
+
try:
|
355 |
+
config = LoraConfig(
|
356 |
+
r=lora_r,
|
357 |
+
lora_alpha=lora_alpha,
|
358 |
+
target_modules=lora_target_modules,
|
359 |
+
lora_dropout=lora_dropout,
|
360 |
+
bias="none",
|
361 |
+
task_type="CAUSAL_LM",
|
362 |
+
)
|
363 |
+
model = get_peft_model(model, config)
|
364 |
+
break
|
365 |
+
except ValueError as e:
|
366 |
+
if "Target modules" in str(e) and "not found" in str(e):
|
367 |
+
continue
|
368 |
+
else:
|
369 |
+
raise
|
370 |
+
from peft import PeftModel
|
371 |
+
assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
|
372 |
+
if resume_from_checkpoint:
|
373 |
+
# Check the available weights and load them
|
374 |
+
checkpoint_name = os.path.join(
|
375 |
+
resume_from_checkpoint, "pytorch_model.bin"
|
376 |
+
) # Full checkpoint
|
377 |
+
if not os.path.exists(checkpoint_name):
|
378 |
+
checkpoint_name = os.path.join(
|
379 |
+
resume_from_checkpoint, "adapter_model.bin"
|
380 |
+
) # only LoRA model - LoRA config above has to fit
|
381 |
+
resume_from_checkpoint = False # So the trainer won't try loading its state
|
382 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
383 |
+
if os.path.exists(checkpoint_name):
|
384 |
+
log(f"Restarting from {checkpoint_name}")
|
385 |
+
adapters_weights = torch.load(checkpoint_name)
|
386 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
387 |
+
else:
|
388 |
+
log(f"Checkpoint {checkpoint_name} not found")
|
389 |
+
|
390 |
+
print(model)
|
391 |
+
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
392 |
+
|
393 |
+
metrics = {}
|
394 |
+
for name in supported_metrics:
|
395 |
+
if name in val_metrics:
|
396 |
+
import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
|
397 |
+
metrics[name] = evaluate.load(name)
|
398 |
+
log("Using Validation Metrics: %s" % str(list(metrics.keys())))
|
399 |
+
log("Supported Metrics: %s" % supported_metrics)
|
400 |
+
|
401 |
+
if val_set_size is None:
|
402 |
+
if len(metrics) == 0:
|
403 |
+
val_set_size = 1000
|
404 |
+
else:
|
405 |
+
val_set_size = 100
|
406 |
+
log("Auto set val_set_size %s" % val_set_size)
|
407 |
+
elif val_set_size < 1.0 and val_set_size != 0:
|
408 |
+
raise RuntimeError("Fractional validation size not supported.")
|
409 |
+
|
410 |
+
if valid_path:
|
411 |
+
data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
|
412 |
+
else:
|
413 |
+
if "json" in data_path:
|
414 |
+
data = load_dataset("json", data_files={"train": data_path})
|
415 |
+
else:
|
416 |
+
data = load_dataset(data_path)
|
417 |
+
data = data.rename_columns(data_col_dict or {})
|
418 |
+
|
419 |
+
valid_data = None
|
420 |
+
train_data_mix_in = None
|
421 |
+
valid_data_mix_in = None
|
422 |
+
|
423 |
+
if data_mix_in_path and data_mix_in_factor > 0:
|
424 |
+
# get mix-in training/validation data - to keep model "sane"
|
425 |
+
num_rows = data["train"].num_rows
|
426 |
+
log("Loading mix-in dataset: %s" % data_mix_in_path)
|
427 |
+
if "json" in data_mix_in_path:
|
428 |
+
data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
|
429 |
+
else:
|
430 |
+
data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
|
431 |
+
data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
|
432 |
+
|
433 |
+
# only get as much as we need to balance
|
434 |
+
valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
|
435 |
+
train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
|
436 |
+
mixin_small = data_mix_in.train_test_split(
|
437 |
+
test_size=train_size + valid_size,
|
438 |
+
shuffle=True, seed=np.random.randint(10000),
|
439 |
+
)["test"]
|
440 |
+
if valid_size:
|
441 |
+
mixin_train_test = mixin_small.train_test_split(
|
442 |
+
test_size=valid_size, shuffle=False,
|
443 |
+
)
|
444 |
+
train_data_mix_in = mixin_train_test["train"]
|
445 |
+
valid_data_mix_in = mixin_train_test["test"]
|
446 |
+
else:
|
447 |
+
train_data_mix_in = mixin_small
|
448 |
+
|
449 |
+
if "prompt_type" not in train_data_mix_in.column_names:
|
450 |
+
train_data_mix_in = train_data_mix_in.add_column(
|
451 |
+
"prompt_type",
|
452 |
+
[data_mix_in_prompt_type] * train_data_mix_in.num_rows,
|
453 |
+
)
|
454 |
+
log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
|
455 |
+
if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
|
456 |
+
valid_data_mix_in = valid_data_mix_in.add_column(
|
457 |
+
"prompt_type",
|
458 |
+
[data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
|
459 |
+
)
|
460 |
+
log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
|
461 |
+
log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
|
462 |
+
|
463 |
+
# get our own training/validation data - for fine-tuning
|
464 |
+
if val_set_size > 0 and not valid_path and not data_mix_in_path:
|
465 |
+
# create valid split from train
|
466 |
+
train_val = data["train"].train_test_split(
|
467 |
+
test_size=val_set_size, shuffle=True, seed=42
|
468 |
+
)
|
469 |
+
train_data = train_val["train"]
|
470 |
+
valid_data = train_val["test"]
|
471 |
+
else:
|
472 |
+
train_data = data["train"]
|
473 |
+
if valid_path:
|
474 |
+
# use given valid split, has priority over data_mix_in_path
|
475 |
+
valid_data = data["valid"]
|
476 |
+
if "prompt_type" not in train_data.column_names:
|
477 |
+
train_data = train_data.add_column(
|
478 |
+
"prompt_type",
|
479 |
+
[prompt_type] * train_data.num_rows,
|
480 |
+
)
|
481 |
+
log("Added prompt type %s to training data" % prompt_type)
|
482 |
+
if valid_data and "prompt_type" not in valid_data.column_names:
|
483 |
+
valid_data = valid_data.add_column(
|
484 |
+
"prompt_type",
|
485 |
+
[prompt_type] * valid_data.num_rows,
|
486 |
+
)
|
487 |
+
log("Added prompt type %s to validation data" % prompt_type)
|
488 |
+
|
489 |
+
assert train_data is not None
|
490 |
+
|
491 |
+
# shuffle and tokenize data
|
492 |
+
if train_data_mix_in:
|
493 |
+
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
494 |
+
train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
|
495 |
+
train_set_size = len(train_data)
|
496 |
+
|
497 |
+
if valid_data and valid_data_mix_in:
|
498 |
+
valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
|
499 |
+
elif valid_data_mix_in:
|
500 |
+
valid_data = valid_data_mix_in
|
501 |
+
|
502 |
+
if valid_data:
|
503 |
+
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
|
504 |
+
val_set_size = len(valid_data)
|
505 |
+
else:
|
506 |
+
val_set_size = 0
|
507 |
+
log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
|
508 |
+
sample_row_dict = train_data[:1]
|
509 |
+
del sample_row_dict['input_ids']
|
510 |
+
del sample_row_dict['attention_mask']
|
511 |
+
del sample_row_dict['labels']
|
512 |
+
log("Sample input: %s" % sample_row_dict)
|
513 |
+
|
514 |
+
if neptune_run:
|
515 |
+
neptune_callback = NeptuneCallback(run=neptune_run)
|
516 |
+
callbacks = [neptune_callback]
|
517 |
+
else:
|
518 |
+
from transformers.integrations import TensorBoardCallback, is_tensorboard_available
|
519 |
+
if is_tensorboard_available:
|
520 |
+
# tensorboard --logdir=runs/
|
521 |
+
from torch.utils.tensorboard import SummaryWriter
|
522 |
+
tb_writer = SummaryWriter()
|
523 |
+
callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
|
524 |
+
else:
|
525 |
+
callbacks = []
|
526 |
+
|
527 |
+
expected_steps = (train_set_size * num_epochs) // batch_size
|
528 |
+
if eval_steps is None and eval_epochs is None:
|
529 |
+
# 20 evaluations for a run
|
530 |
+
eval_steps = max(1, int(expected_steps / 20))
|
531 |
+
log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
|
532 |
+
elif eval_steps is None and eval_epochs is not None:
|
533 |
+
eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
|
534 |
+
log("Auto converted eval_epochs=%s to eval_steps %s"
|
535 |
+
" out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
|
536 |
+
if save_steps is None:
|
537 |
+
save_steps = eval_steps
|
538 |
+
log("Auto step save_steps to %s" % save_steps)
|
539 |
+
elif save_steps > eval_steps:
|
540 |
+
# save steps must be round multiple of eval_steps
|
541 |
+
save_steps0 = save_steps
|
542 |
+
save_steps = max(1, (save_steps//eval_steps)) * eval_steps
|
543 |
+
if save_steps0 != save_steps:
|
544 |
+
log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
|
545 |
+
|
546 |
+
def compute_metrics(eval_preds):
|
547 |
+
# e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
|
548 |
+
inputs = eval_preds.inputs
|
549 |
+
label_ids = eval_preds.label_ids
|
550 |
+
predictions = eval_preds.predictions
|
551 |
+
|
552 |
+
#inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
|
553 |
+
#decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
|
554 |
+
#decoded_inputs = [pred.strip() for pred in decoded_inputs]
|
555 |
+
|
556 |
+
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
557 |
+
# tokenizer behavior like generate time
|
558 |
+
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
|
559 |
+
clean_up_tokenization_spaces=True)
|
560 |
+
decoded_labels = [pred.strip() for pred in decoded_labels]
|
561 |
+
|
562 |
+
predictions = np.argmax(predictions, -1)
|
563 |
+
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
564 |
+
# tokenizer behavior like generate time
|
565 |
+
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
|
566 |
+
clean_up_tokenization_spaces=True)
|
567 |
+
decoded_predictions = [pred.strip() for pred in decoded_predictions]
|
568 |
+
|
569 |
+
result = {}
|
570 |
+
for metric in metrics.values():
|
571 |
+
result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
|
572 |
+
# get rid of lists, for precision etc., for now
|
573 |
+
numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
|
574 |
+
result.update(numeric_results)
|
575 |
+
return result
|
576 |
+
|
577 |
+
# the callback that computes metrics of interest
|
578 |
+
if val_metrics:
|
579 |
+
trainer_kwargs = dict(compute_metrics=compute_metrics)
|
580 |
+
else:
|
581 |
+
trainer_kwargs = dict()
|
582 |
+
|
583 |
+
trainer = transformers.Trainer(
|
584 |
+
model=model,
|
585 |
+
tokenizer=tokenizer,
|
586 |
+
train_dataset=train_data,
|
587 |
+
eval_dataset=valid_data,
|
588 |
+
# NOTE: CausalLM is not supporting Seq2SeqTrainingArguments arguments, but not incompatible
|
589 |
+
args=transformers.Seq2SeqTrainingArguments(
|
590 |
+
per_device_train_batch_size=micro_batch_size,
|
591 |
+
per_device_eval_batch_size=1,
|
592 |
+
eval_accumulation_steps=10,
|
593 |
+
# predict_with_generate=True, # SEQ2SEQ only
|
594 |
+
include_inputs_for_metrics=True,
|
595 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
596 |
+
warmup_steps=warmup_steps,
|
597 |
+
num_train_epochs=num_epochs,
|
598 |
+
learning_rate=learning_rate,
|
599 |
+
gradient_checkpointing=gradient_checkpointing,
|
600 |
+
fp16=fp16,
|
601 |
+
# cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
|
602 |
+
optim="adamw_torch", # consider "adafactor" to save memory
|
603 |
+
logging_steps=logging_steps,
|
604 |
+
logging_strategy="steps",
|
605 |
+
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
606 |
+
save_strategy="steps",
|
607 |
+
eval_steps=eval_steps if val_set_size > 0 else None,
|
608 |
+
save_steps=save_steps,
|
609 |
+
output_dir=output_dir,
|
610 |
+
save_total_limit=3,
|
611 |
+
load_best_model_at_end=True if val_set_size > 0 else False,
|
612 |
+
ddp_find_unused_parameters=False if ddp else None,
|
613 |
+
group_by_length=group_by_length,
|
614 |
+
#fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
|
615 |
+
#fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
|
616 |
+
report_to='tensorboard' if not neptune_run else 'neptune',
|
617 |
+
),
|
618 |
+
data_collator=transformers.DataCollatorForSeq2Seq(
|
619 |
+
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
620 |
+
),
|
621 |
+
callbacks=callbacks,
|
622 |
+
**trainer_kwargs,
|
623 |
+
)
|
624 |
+
model.config.use_cache = False
|
625 |
+
|
626 |
+
old_state_dict = model.state_dict
|
627 |
+
model.state_dict = (
|
628 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
629 |
+
).__get__(model, type(model))
|
630 |
+
|
631 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
632 |
+
model = torch.compile(model)
|
633 |
+
# WIP (not generally replacing layers until pytorch 2.1)
|
634 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
635 |
+
|
636 |
+
if gpus > 1 and not ddp:
|
637 |
+
assert trainer.is_model_parallel
|
638 |
+
else:
|
639 |
+
assert not trainer.is_model_parallel
|
640 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
641 |
+
|
642 |
+
model.save_pretrained(output_dir)
|
643 |
+
|
644 |
+
log("\n If there's a warning about missing keys above, please disregard :)")
|
645 |
+
|
646 |
+
|
647 |
+
def get_loaders(llama_type, model_name, reward_type):
|
648 |
+
# NOTE: Some models need specific new prompt_type
|
649 |
+
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
650 |
+
if llama_type:
|
651 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
652 |
+
model_loader = LlamaForCausalLM
|
653 |
+
tokenizer_loader = LlamaTokenizer
|
654 |
+
elif 'gpt2' in model_name.lower():
|
655 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
656 |
+
return GPT2LMHeadModel, GPT2Tokenizer
|
657 |
+
elif 'mbart-' in model_name.lower():
|
658 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
659 |
+
return MBartForConditionalGeneration, MBart50TokenizerFast
|
660 |
+
elif 't5' == model_name.lower() or \
|
661 |
+
't5-' in model_name.lower() or \
|
662 |
+
'flan-' in model_name.lower():
|
663 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
664 |
+
return T5ForConditionalGeneration, AutoTokenizer
|
665 |
+
elif 'bigbird' in model_name:
|
666 |
+
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
667 |
+
return BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
668 |
+
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
669 |
+
from transformers import pipeline
|
670 |
+
return pipeline, "summarization"
|
671 |
+
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
672 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
673 |
+
return AutoModelForSequenceClassification, AutoTokenizer
|
674 |
+
else:
|
675 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
676 |
+
model_loader = AutoModelForCausalLM
|
677 |
+
tokenizer_loader = AutoTokenizer
|
678 |
+
return model_loader, tokenizer_loader
|
679 |
+
|
680 |
+
|
681 |
+
def get_githash():
|
682 |
+
try:
|
683 |
+
githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
|
684 |
+
except:
|
685 |
+
githash = ''
|
686 |
+
return githash
|
687 |
+
|
688 |
+
|
689 |
+
def copy_code(run_id):
|
690 |
+
"""
|
691 |
+
copy code to track changes
|
692 |
+
:param run_id:
|
693 |
+
:return:
|
694 |
+
"""
|
695 |
+
rnd_num = str(random.randint(0, 2 ** 31))
|
696 |
+
run_id = 'run_' + str(run_id)
|
697 |
+
os.makedirs(run_id, exist_ok=True)
|
698 |
+
me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
|
699 |
+
me_file = os.path.basename(__file__)
|
700 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash())
|
701 |
+
if os.path.isfile(new_me):
|
702 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
|
703 |
+
shutil.copy(me_full, new_me)
|
704 |
+
else:
|
705 |
+
shutil.copy(me_full, new_me)
|
706 |
+
|
707 |
+
|
708 |
+
def get_prompt(prompt_type, chat, context, reduced):
|
709 |
+
if prompt_type in [-1, "-1", "plain"]:
|
710 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
711 |
+
terminate_response = []
|
712 |
+
elif prompt_type == 'simple_instruct':
|
713 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
714 |
+
terminate_response = []
|
715 |
+
elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
|
716 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
717 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
718 |
+
|
719 |
+
PreInstruct = """
|
720 |
+
### Instruction:
|
721 |
+
"""
|
722 |
+
|
723 |
+
PreInput = """
|
724 |
+
### Input:
|
725 |
+
"""
|
726 |
+
|
727 |
+
PreResponse = """
|
728 |
+
### Response:
|
729 |
+
"""
|
730 |
+
if prompt_type in [7, "7", "instruct_with_end"]:
|
731 |
+
terminate_response = ['### End']
|
732 |
+
else:
|
733 |
+
terminate_response = None
|
734 |
+
elif prompt_type in [1, "1", "quality"]:
|
735 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
|
736 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
|
737 |
+
|
738 |
+
PreInstruct = """
|
739 |
+
### Instruction:
|
740 |
+
"""
|
741 |
+
|
742 |
+
PreInput = """
|
743 |
+
### Input:
|
744 |
+
"""
|
745 |
+
|
746 |
+
PreResponse = """
|
747 |
+
### Response:
|
748 |
+
"""
|
749 |
+
terminate_response = None
|
750 |
+
elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
|
751 |
+
if reduced or context or prompt_type in [2, "2", "human_bot"]:
|
752 |
+
preprompt = ''
|
753 |
+
else:
|
754 |
+
cur_date = time.strftime('%Y-%m-%d')
|
755 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
756 |
+
|
757 |
+
PRE_PROMPT = """\
|
758 |
+
Current Date: {}
|
759 |
+
Current Time: {}
|
760 |
+
|
761 |
+
"""
|
762 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
763 |
+
start = human
|
764 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
765 |
+
|
766 |
+
PreInstruct = ""
|
767 |
+
|
768 |
+
PreInput = None
|
769 |
+
|
770 |
+
PreResponse = bot
|
771 |
+
|
772 |
+
terminate_response = [start, PreResponse]
|
773 |
+
elif prompt_type in [3, "3", "dai_faq"]:
|
774 |
+
promptA = ''
|
775 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
776 |
+
|
777 |
+
PreInstruct = """
|
778 |
+
### Driverless AI frequently asked question:
|
779 |
+
"""
|
780 |
+
|
781 |
+
PreInput = None
|
782 |
+
|
783 |
+
PreResponse = """
|
784 |
+
### Driverless AI documentation answer:
|
785 |
+
"""
|
786 |
+
terminate_response = ['\n\n']
|
787 |
+
elif prompt_type in [5, "5", "summarize"]:
|
788 |
+
promptA = promptB = PreInput = ''
|
789 |
+
PreInstruct = '## Main Text\n\n'
|
790 |
+
PreResponse = '\n\n## Summary\n\n'
|
791 |
+
terminate_response = None
|
792 |
+
elif prompt_type in [6, "6", "instruct_vicuna"]:
|
793 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
794 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
|
795 |
+
|
796 |
+
PreInstruct = """
|
797 |
+
### Human:
|
798 |
+
"""
|
799 |
+
|
800 |
+
PreInput = None
|
801 |
+
|
802 |
+
PreResponse = """
|
803 |
+
### Assistant:
|
804 |
+
"""
|
805 |
+
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
806 |
+
else:
|
807 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
808 |
+
|
809 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
|
810 |
+
|
811 |
+
|
812 |
+
def generate_prompt(data_point, prompt_type, chat, reduced):
|
813 |
+
context = data_point.get('context')
|
814 |
+
if context is None:
|
815 |
+
context = ''
|
816 |
+
instruction = data_point.get('instruction')
|
817 |
+
input = data_point.get('input')
|
818 |
+
output = data_point.get('output')
|
819 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
820 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
821 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
|
822 |
+
|
823 |
+
prompt = context
|
824 |
+
|
825 |
+
if input and promptA:
|
826 |
+
prompt += f"""{promptA}"""
|
827 |
+
elif promptB:
|
828 |
+
prompt += f"""{promptB}"""
|
829 |
+
|
830 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
831 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
832 |
+
prompt = inject_newline(prompt_type, prompt)
|
833 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
834 |
+
prompt += f"""{PreInput}{instruction}
|
835 |
+
{input}"""
|
836 |
+
prompt = inject_newline(prompt_type, prompt)
|
837 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
838 |
+
prompt += f"""{PreInstruct}{instruction}
|
839 |
+
{input}"""
|
840 |
+
prompt = inject_newline(prompt_type, prompt)
|
841 |
+
elif instruction and PreInstruct is not None:
|
842 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
843 |
+
prompt = inject_newline(prompt_type, prompt)
|
844 |
+
elif input and PreInput is not None:
|
845 |
+
prompt += f"""{PreInput}{input}"""
|
846 |
+
prompt = inject_newline(prompt_type, prompt)
|
847 |
+
elif input and instruction and PreInput is not None:
|
848 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
849 |
+
prompt = inject_newline(prompt_type, prompt)
|
850 |
+
elif input and instruction and PreInstruct is not None:
|
851 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
852 |
+
prompt = inject_newline(prompt_type, prompt)
|
853 |
+
elif input and instruction:
|
854 |
+
# i.e. for simple_instruct
|
855 |
+
prompt += f"""{instruction}: {input}"""
|
856 |
+
prompt = inject_newline(prompt_type, prompt)
|
857 |
+
elif input:
|
858 |
+
prompt += f"""{input}"""
|
859 |
+
prompt = inject_newline(prompt_type, prompt)
|
860 |
+
elif instruction:
|
861 |
+
prompt += f"""{instruction}"""
|
862 |
+
prompt = inject_newline(prompt_type, prompt)
|
863 |
+
|
864 |
+
if PreResponse is not None:
|
865 |
+
prompt += f"""{PreResponse}"""
|
866 |
+
pre_response = PreResponse # Don't use strip
|
867 |
+
else:
|
868 |
+
pre_response = ''
|
869 |
+
|
870 |
+
if output:
|
871 |
+
prompt += f"""{output}"""
|
872 |
+
|
873 |
+
return prompt, pre_response, terminate_response
|
874 |
+
|
875 |
+
|
876 |
+
def inject_newline(prompt_type, prompt):
|
877 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
878 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
879 |
+
prompt += '\n'
|
880 |
+
return prompt
|
881 |
+
|
882 |
+
|
883 |
+
example_data_point0 = dict(instruction="Summarize",
|
884 |
+
input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
|
885 |
+
output="Ducks eat and swim at the lake.")
|
886 |
+
|
887 |
+
example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
|
888 |
+
output="Einstein.")
|
889 |
+
|
890 |
+
example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
|
891 |
+
output="Einstein.")
|
892 |
+
|
893 |
+
example_data_points = [example_data_point0, example_data_point1, example_data_point2]
|
894 |
+
|
895 |
+
|
896 |
+
def test_train_prompt(prompt_type='instruct', data_point=0):
|
897 |
+
example_data_point = example_data_points[data_point]
|
898 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
899 |
+
|
900 |
+
|
901 |
+
def test_debug():
|
902 |
+
fire.Fire(train)
|
903 |
+
|
904 |
+
|
905 |
+
if __name__ == "__main__":
|
906 |
+
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
907 |
+
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
908 |
+
log(f"""
|
909 |
+
Example runs on 4 GPUs:
|
910 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
|
911 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
|
912 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
|
913 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
|
914 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
|
915 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
|
916 |
+
|
917 |
+
All metrics:
|
918 |
+
CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
|
919 |
+
|
920 |
+
# Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
|
921 |
+
rippa>
|
922 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
|
923 |
+
ova>
|
924 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
|
925 |
+
timemachine>
|
926 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
|
927 |
+
|
928 |
+
""", flush=True)
|
929 |
+
|
930 |
+
if os.environ.get("LOCAL_RANK") is None:
|
931 |
+
# then not using torchrun, so can't do distributed, ensure CVD set
|
932 |
+
assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
933 |
+
|
934 |
+
fire.Fire(train)
|
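# Illustrative note (not part of finetune.py): approximate output of generate_prompt()
# for the 'human_bot' prompt type with example_data_point1, following the get_prompt()
# logic above:
#   prompt, pre_response, terminate_response = test_train_prompt('human_bot', data_point=1)
#   prompt             -> "<human>: Who is smarter, Einstein or Newton?\n<bot>:Einstein."
#   pre_response       -> "<bot>:"
#   terminate_response -> ['<human>:', '<bot>:']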
h2o-logo.svg
ADDED
prompter.py
ADDED
@@ -0,0 +1,106 @@
from finetune import generate_prompt


class Prompter(object):
    def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                 allowed_repeat_line_length=10):
        self.prompt_type = prompt_type
        data_point = dict(instruction='', input='', output='')
        _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
        self.debug = debug
        self.chat = chat
        self.stream_output = stream_output
        self.repeat_penalty = repeat_penalty
        self.allowed_repeat_line_length = allowed_repeat_line_length

    def generate_prompt(self, data_point):
        reduced = False
        prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
        if self.debug:
            print("prompt: ", prompt, flush=True)
        self.prompt = prompt
        return prompt

    def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
        if isinstance(outputs, str):
            outputs = [outputs]
        if self.debug:
            print("output: ", '\n\n'.join(outputs), flush=True)
        if prompt is not None:
            self.prompt = prompt

        def clean_response(response):
            meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
            for word in meaningless_words:
                response = response.replace(word, "")
            if sanitize_bot_response:
                from better_profanity import profanity
                response = profanity.censor(response)
            response = response.strip("\n")
            return response

        def clean_repeats(response):
            lines = response.split('\n')
            new_lines = []
            [new_lines.append(line) for line in lines if
             line not in new_lines or len(line) < self.allowed_repeat_line_length]
            if self.debug and len(lines) != len(new_lines):
                print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
            response = '\n'.join(new_lines)
            return response

        multi_output = len(outputs) > 1

        for oi, output in enumerate(outputs):
            if self.prompt_type in [0, '0', 'plain']:
                output = clean_response(output)
            else:
                # find first instance of pre_response
                # prompt sometimes has odd characters that mutate its length,
                # so can't go by length alone
                if self.pre_response:
                    outputi = output.find(prompt)
                    if outputi >= 0:
                        output = output[outputi + len(prompt):]
                        allow_terminate = True
                    else:
                        # subtraction is risky due to space offsets sometimes, so only do if necessary
                        output = output[len(prompt) - len(self.pre_response):]
                        # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
                        if self.pre_response in output:
                            output = output.split(self.pre_response)[1]
                            allow_terminate = True
                        else:
                            print("Failure of parsing: %s" % output, flush=True)
                            allow_terminate = False
                else:
                    allow_terminate = True
                    output = output[len(prompt):]
                # clean after subtracting prompt out, so correct removal of pre_response
                output = clean_response(output).strip()
                if self.repeat_penalty:
                    output = clean_repeats(output).strip()
                if self.terminate_response and allow_terminate:
                    finds = []
                    for term in self.terminate_response:
                        finds.append(output.find(term))
                    finds = [x for x in finds if x >= 0]
                    if len(finds) > 0:
                        termi = finds[0]
                        output = output[:termi].strip()
                    else:
                        output = output.strip()
                else:
                    output = output.strip()
            if multi_output:
                # prefix with output counter
                output = "\n=========== Output %d\n\n" % (1 + oi) + output
                if oi > 0:
                    # postfix outputs with separator
                    output += '\n'
            outputs[oi] = output
        # join all outputs, only one extra new line between outputs
        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean: ", '\n\n'.join(outputs), flush=True)
        return output
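# Illustrative usage (not part of the original prompter.py): build a prompt, then
# strip the echoed prompt and stop phrases from a raw generation. The instruction
# text, prompt type, and fake model output below are examples only.
if __name__ == '__main__':
    prompter = Prompter('human_bot', debug=False, chat=False, stream_output=False)
    prompt = prompter.generate_prompt(dict(instruction="Who are you?", input='', output=''))
    raw_model_output = prompt + " I am h2oGPT.\n<human>: ignored follow-up"
    print(prompter.get_response(raw_model_output, prompt=prompt))  # -> "I am h2oGPT."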
requirements.txt
ADDED
@@ -0,0 +1,48 @@
# for generate (gradio server) and finetune
datasets==2.10.1
sentencepiece==0.1.97
accelerate==0.18.0
gradio==3.27.0
huggingface_hub==0.13.4
appdirs==1.4.4
fire==0.5.0
docutils==0.19
torch==2.0.0
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.6.1
numpy==1.24.2
pandas==1.5.3
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.38.1
git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
transformers==4.28.1
tokenizers==0.13.3

# optional for generate
pynvml==11.5.0
psutil==5.9.4

# optional for finetune
tensorboard==2.12.1
neptune==1.1.1

# for gradio client
gradio_client==0.1.3
beautifulsoup4==4.12.2
markdown==3.4.1

# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
pandoc==2.3
pypandoc==1.11
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
stopping.py
ADDED
@@ -0,0 +1,139 @@
import traceback
from queue import Queue
from threading import Thread
import collections.abc

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=[]):
        super().__init__()
        assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
        self.encounters = encounters
        self.stops = [stop.to("cuda") for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stopi, stop in enumerate(self.stops):
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
        return False
28 |
+
|
29 |
+
|
30 |
+
class Stream(StoppingCriteria):
|
31 |
+
"""
|
32 |
+
This class can be used to callback during generation. Keep
|
33 |
+
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
func (`callable`):
|
37 |
+
A callable function to apply on first input in list every iteration of generation
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, func=None):
|
41 |
+
self.func = func
|
42 |
+
|
43 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
44 |
+
if self.func is not None:
|
45 |
+
# only consume first of multiple responses
|
46 |
+
self.func(input_ids[0])
|
47 |
+
return False
|
48 |
+
|
49 |
+
|
50 |
+
class CallbackToGenerator(collections.abc.Generator):
|
51 |
+
"""
|
52 |
+
A generator wrapper for a function that invokes a callback multiple times.
|
53 |
+
|
54 |
+
Calling `send` on the generator emits a value from one callback, and returns
|
55 |
+
the next.
|
56 |
+
|
57 |
+
Note this starts a background thread
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, func, *args, callback=None, **kwargs):
|
61 |
+
self.func = func
|
62 |
+
self.args = args
|
63 |
+
self.kwargs = kwargs
|
64 |
+
self.callback = callback
|
65 |
+
|
66 |
+
self._ready_queue = Queue(1)
|
67 |
+
self._done_queue = Queue(1)
|
68 |
+
self._done_holder = [False]
|
69 |
+
|
70 |
+
# local to avoid reference cycles
|
71 |
+
ready_queue = self._ready_queue
|
72 |
+
done_queue = self._done_queue
|
73 |
+
done_holder = self._done_holder
|
74 |
+
|
75 |
+
def val_callback(value):
|
76 |
+
done_queue.put((False, value))
|
77 |
+
cmd, val = ready_queue.get()
|
78 |
+
if cmd == 'send':
|
79 |
+
return val
|
80 |
+
elif cmd == 'throw':
|
81 |
+
raise val
|
82 |
+
else:
|
83 |
+
assert False # pragma: no cover
|
84 |
+
|
85 |
+
def thread_func():
|
86 |
+
while True:
|
87 |
+
cmd, val = ready_queue.get()
|
88 |
+
if cmd == 'send' and val is not None:
|
89 |
+
done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
|
90 |
+
continue
|
91 |
+
break
|
92 |
+
try:
|
93 |
+
if cmd == 'throw':
|
94 |
+
raise val
|
95 |
+
ret = func(callback=val_callback, **self.kwargs)
|
96 |
+
raise StopIteration(ret) if ret is not None else StopIteration
|
97 |
+
except BaseException as e:
|
98 |
+
done_holder[0] = True
|
99 |
+
done_queue.put((True, e))
|
100 |
+
|
101 |
+
self._thread = Thread(target=thread_func)
|
102 |
+
self._thread.start()
|
103 |
+
|
104 |
+
def _put(self, *args):
|
105 |
+
if self._done_holder[0]:
|
106 |
+
raise StopIteration
|
107 |
+
self._ready_queue.put(args)
|
108 |
+
is_exception, val = self._done_queue.get()
|
109 |
+
if is_exception:
|
110 |
+
try:
|
111 |
+
raise val
|
112 |
+
finally:
|
113 |
+
# prevent val's traceback containing a reference cycle
|
114 |
+
del val
|
115 |
+
else:
|
116 |
+
return val
|
117 |
+
|
118 |
+
def send(self, value):
|
119 |
+
return self._put('send', value)
|
120 |
+
|
121 |
+
def throw(self, exc):
|
122 |
+
return self._put('throw', exc)
|
123 |
+
|
124 |
+
def close(self):
|
125 |
+
try:
|
126 |
+
self.throw(GeneratorExit)
|
127 |
+
except StopIteration:
|
128 |
+
self._thread.join()
|
129 |
+
except GeneratorExit:
|
130 |
+
self._thread.join()
|
131 |
+
except BaseException:
|
132 |
+
self._thread.join()
|
133 |
+
raise
|
134 |
+
else:
|
135 |
+
# yielded again, can't clean up the thread
|
136 |
+
raise RuntimeError('Task with callback ignored GeneratorExit')
|
137 |
+
|
138 |
+
def __del__(self):
|
139 |
+
self.close()
|
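
StoppingCriteriaSub is a regular transformers StoppingCriteria, so it can be passed to generate through a StoppingCriteriaList. A rough usage sketch, assuming a CUDA device (the class moves its stop tensors to "cuda") and a hypothetical small model/tokenizer choice:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from stopping import StoppingCriteriaSub

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

# encode each stop phrase to token ids; encounters=[1] stops at the first occurrence
stop_words = ["\n<human>:"]
stops = [torch.tensor(tokenizer(w, add_special_tokens=False)["input_ids"]) for w in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stops, encounters=[1])])

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(out[0]))

CallbackToGenerator turns a callback-style function into an iterable, which is how intermediate values can be streamed without rewriting the producer. A small sketch with a hypothetical producer function:

from stopping import CallbackToGenerator

def produce(callback=None):
    # hypothetical producer that reports three intermediate values
    for i in range(3):
        callback(i)

for value in CallbackToGenerator(produce):
    print(value)  # prints 0, 1, 2 as the background thread emits them
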
utils.py
ADDED
@@ -0,0 +1,154 @@
+import os
+import gc
+import random
+import time
+import traceback
+import zipfile
+from datetime import datetime
+import filelock
+import numpy as np
+import pandas as pd
+import torch
+
+
+def set_seed(seed: int):
+    """
+    Sets the seed of the entire notebook so results are the same every time we run.
+    This is for REPRODUCIBILITY.
+    """
+    np.random.seed(seed)
+    random_state = np.random.RandomState(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    return random_state
+
+
+def flatten_list(lis):
+    """Given a list, possibly nested to any level, return it flattened."""
+    new_lis = []
+    for item in lis:
+        if type(item) == type([]):
+            new_lis.extend(flatten_list(item))
+        else:
+            new_lis.append(item)
+    return new_lis
+
+
+def clear_torch_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    gc.collect()
+
+
+def system_info():
+    import psutil
+
+    system = {}
+    # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
+    # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
+    temps = psutil.sensors_temperatures(fahrenheit=False)
+    if 'coretemp' in temps:
+        coretemp = temps['coretemp']
+        temp_dict = {k.label: k.current for k in coretemp}
+        for k, v in temp_dict.items():
+            system['CPU_C/%s' % k] = v
+
+    # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
+    from pynvml.smi import nvidia_smi
+    nvsmi = nvidia_smi.getInstance()
+
+    gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
+                      enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
+    for k, v in gpu_power_dict.items():
+        system['GPU_W/%s' % k] = v
+
+    gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
+                     enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
+    for k, v in gpu_temp_dict.items():
+        system['GPU_C/%s' % k] = v
+
+    gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
+                            enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
+    gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
+                             enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
+    gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
+    for k, v in gpu_memory_frac_dict.items():
+        system['GPU_M/%s' % k] = v
+
+    return system
+
+
+def system_info_print():
+    try:
+        df = pd.DataFrame.from_dict(system_info(), orient='index')
+        # avoid slamming GPUs
+        time.sleep(1)
+        return df.to_markdown()
+    except Exception as e:
+        return "Error: %s" % str(e)
+
+
+def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
+    try:
+        return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in zipping: %s' % str(e))
+
+
+def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
+    if zip_file is None:
+        datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
+        host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
+        zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
+    assert root_dirs is not None
+
+    with zipfile.ZipFile(zip_file, "w") as expt_zip:
+        for root_dir in root_dirs:
+            if root_dir is None:
+                continue
+            for root, d, files in os.walk(root_dir):
+                for file in files:
+                    file_to_archive = os.path.join(root, file)
+                    assert os.path.exists(file_to_archive)
+                    path_to_archive = os.path.relpath(file_to_archive, base_dir)
+                    expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
+    return zip_file
+
+
+def save_generate_output(output=None, base_model=None, save_dir=None):
+    try:
+        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in saving: %s' % str(e))
+
+
+def _save_generate_output(output=None, base_model=None, save_dir=None):
+    """
+    Save conversation to .json, row by row.
+    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
+    Appends if file exists
+    """
+    assert save_dir, "save_dir must be provided"
+    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
+        raise RuntimeError("save_dir already exists and is not a directory!")
+    os.makedirs(save_dir, exist_ok=True)
+    import json
+    if output[-10:] == '\n\n<human>:':
+        # remove trailing <human>:
+        output = output[:-10]
+    with filelock.FileLock("save_dir.lock"):
+        # lock logging in case have concurrency
+        with open(os.path.join(save_dir, "history.json"), "a") as f:
+            # just add [ at start, and ] at end, and have proper JSON dataset
+            f.write(
+                " " + json.dumps(
+                    dict(text=output, time=time.ctime(), base_model=base_model)
+                ) + ",\n"
+            )
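
A short sketch of how these helpers are typically combined (directory and model names here are illustrative; system_info_print additionally needs psutil/pynvml and an NVIDIA GPU to report anything useful):

from utils import set_seed, system_info_print, zip_data, save_generate_output

set_seed(1234)              # make torch/numpy/python RNGs reproducible
print(system_info_print())  # markdown table of CPU/GPU temperature, power and memory

save_generate_output(output="Hello!\n\n<human>:",
                     base_model="some/base-model",  # hypothetical model name
                     save_dir="save_dir")           # appends one JSON row; trailing <human>: is stripped
zip_file = zip_data(root_dirs=["save_dir"], zip_file="backup.zip")  # archive the logs for download
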