vs4vijay and pseudotensor committed on
Commit b6bff08 · 0 Parent(s)

Duplicate from h2oai/h2ogpt-chatbot

Co-authored-by: Jonathan McKinney <pseudotensor@users.noreply.huggingface.co>

Files changed (11)
  1. .gitattributes +34 -0
  2. LICENSE +201 -0
  3. README.md +14 -0
  4. app.py +1959 -0
  5. client_test.py +93 -0
  6. finetune.py +934 -0
  7. h2o-logo.svg +1 -0
  8. prompter.py +106 -0
  9. requirements.txt +48 -0
  10. stopping.py +139 -0
  11. utils.py +154 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: H2ogpt Chatbot
+ emoji: 📚
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.27.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: h2oai/h2ogpt-chatbot
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1959 @@
1
+ import functools
2
+ import inspect
3
+ import sys
4
+ import os
5
+ import traceback
6
+ import typing
7
+ from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
8
+
9
+ SEED = 1236
10
+ set_seed(SEED)
11
+
12
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
13
+ from typing import Union
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ import fire
18
+ import torch
19
+ from peft import PeftModel
20
+ from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
21
+ from accelerate import init_empty_weights, infer_auto_device_map
22
+
23
+ from prompter import Prompter
24
+
25
+ from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
26
+ human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
27
+ from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
28
+
29
+ is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
30
+ is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
31
+ is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
32
+ is_low_mem = is_hf # assumes run on 24GB consumer GPU
33
+ admin_pass = os.getenv("ADMIN_PASS")
34
+ # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
35
+ raise_generate_gpu_exceptions = True
36
+
37
+ eval_extra_columns = ['prompt', 'response', 'score']
38
+
39
+ def main(
40
+ load_8bit: bool = False,
41
+ load_half: bool = True,
42
+ infer_devices: bool = True,
43
+ base_model: str = '',
44
+ tokenizer_base_model: str = '',
45
+ lora_weights: str = "",
46
+ gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
47
+
48
+ prompt_type: Union[int, str] = None,
49
+ # input to generation
50
+ temperature: float = None,
51
+ top_p: float = None,
52
+ top_k: int = None,
53
+ num_beams: int = None,
54
+ repetition_penalty: float = None,
55
+ num_return_sequences: int = None,
56
+ do_sample: bool = None,
57
+ max_new_tokens: int = None,
58
+ min_new_tokens: int = None,
59
+ early_stopping: Union[bool, str] = None,
60
+ max_time: float = None,
61
+
62
+ llama_type: bool = None,
63
+ debug: bool = False,
64
+ save_dir: str = None,
65
+ share: bool = True,
66
+ local_files_only: bool = False,
67
+ resume_download: bool = True,
68
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
69
+
70
+ src_lang: str = "English",
71
+ tgt_lang: str = "Russian",
72
+
73
+ gradio: bool = True,
74
+ gradio_avoid_processing_markdown: bool = False,
75
+ chat: bool = True,
76
+ chat_history: int = 4096, # character length of chat context/history
77
+ stream_output: bool = True,
78
+ show_examples: bool = None,
79
+ verbose: bool = False,
80
+ h2ocolors: bool = True,
81
+ height: int = 400,
82
+ show_lora: bool = True,
83
+ # set to True to load --base_model after client logs in,
84
+ # to be able to free GPU memory when model is swapped
85
+ login_mode_if_model0: bool = False,
86
+
87
+ sanitize_user_prompt: bool = True,
88
+ sanitize_bot_response: bool = True,
89
+
90
+ extra_model_options: typing.List[str] = [],
91
+ extra_lora_options: typing.List[str] = [],
92
+
93
+ score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
94
+ auto_score: bool = True,
95
+
96
+ eval_sharegpt_prompts_only: int = 0,
97
+ eval_sharegpt_prompts_only_seed: int = 1234,
98
+ eval_sharegpt_as_output: bool = False,
99
+ ):
100
+ # allow set token directly
101
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
102
+
103
+ if is_public:
104
+ temperature = 0.4
105
+ top_p = 0.85
106
+ top_k = 70
107
+ do_sample = True
108
+ if is_low_mem:
109
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
110
+ load_8bit = True
111
+ else:
112
+ base_model = 'h2oai/h2ogpt-oasst1-512-20b'
113
+ if is_low_mem:
114
+ load_8bit = True
115
+ if is_hf:
116
+ # must override share if in spaces
117
+ share = False
118
+ save_dir = os.getenv('SAVE_DIR', save_dir)
119
+
120
+ # get defaults
121
+ model_lower = base_model.lower()
122
+ if not gradio:
123
+ # force streaming off, else output is not the single response we want to inspect
124
+ stream_output = False
125
+ # else prompt removal can mess up output
126
+ chat = False
127
+
128
+ placeholder_instruction, placeholder_input, \
129
+ stream_output, show_examples, \
130
+ prompt_type, temperature, top_p, top_k, num_beams, \
131
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
132
+ repetition_penalty, num_return_sequences, \
133
+ do_sample, \
134
+ src_lang, tgt_lang, \
135
+ examples, \
136
+ task_info = \
137
+ get_generate_params(model_lower, chat,
138
+ stream_output, show_examples,
139
+ prompt_type, temperature, top_p, top_k, num_beams,
140
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
141
+ repetition_penalty, num_return_sequences,
142
+ do_sample,
143
+ )
144
+
145
+ if not gradio:
146
+ if eval_sharegpt_prompts_only > 0:
147
+ # override default examples with shareGPT ones for human-level eval purposes only
148
+ eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
149
+ if not os.path.isfile(eval_filename):
150
+ os.system(
151
+ 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
152
+ import json
153
+ data = json.load(open(eval_filename, 'rt'))
154
+ # focus on data that starts with human, else likely chopped from other data
155
+ turn_start = 0 # odd in general
156
+ data = [x for x in data if len(x['conversations']) > turn_start + 1 and
157
+ x['conversations'][turn_start]['from'] == 'human' and
158
+ x['conversations'][turn_start + 1]['from'] == 'gpt']
159
+ np.random.seed(eval_sharegpt_prompts_only_seed)
160
+ example1 = examples[-1] # pick reference example
161
+ examples = []
162
+ responses = []
163
+ for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
164
+ assert data[i]['conversations'][turn_start]['from'] == 'human'
165
+ instruction = data[i]['conversations'][turn_start]['value']
166
+ assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
167
+ output = data[i]['conversations'][turn_start + 1]['value']
168
+ examplenew = example1.copy()
169
+ assert not chat, "No gradio: must use chat=False, uses nochat instruct"
170
+ examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
171
+ examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
172
+ examplenew[eval_func_param_names.index('context')] = '' # no context
173
+ examples.append(examplenew)
174
+ responses.append(output)
175
+
176
+ num_examples = len(examples)
177
+ scoring_path = 'scoring'
178
+ os.makedirs(scoring_path, exist_ok=True)
179
+ if eval_sharegpt_as_output:
180
+ used_base_model = 'gpt35'
181
+ used_lora_weights = ''
182
+ else:
183
+ used_base_model = str(base_model.split('/')[-1])
184
+ used_lora_weights = str(lora_weights.split('/')[-1])
185
+ eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
186
+ eval_sharegpt_prompts_only_seed,
187
+ eval_sharegpt_as_output,
188
+ used_base_model,
189
+ used_lora_weights)
190
+ eval_filename = os.path.join(scoring_path, eval_filename)
191
+
192
+ with torch.device("cuda"):
193
+ # ensure was set right above before examples generated
194
+ assert not stream_output, "stream_output=True does not make sense with example loop"
195
+ import time
196
+ from functools import partial
197
+
198
+ # get score model
199
+ smodel, stokenizer, sdevice = get_score_model(**locals())
200
+
201
+ if not eval_sharegpt_as_output:
202
+ model, tokenizer, device = get_model(**locals())
203
+ model_state = [model, tokenizer, device, base_model]
204
+ fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir)
205
+ else:
206
+ assert eval_sharegpt_prompts_only > 0
207
+
208
+ def get_response(*args, exi=0):
209
+ # assumes same ordering of examples and responses
210
+ yield responses[exi]
211
+
212
+ fun = get_response
213
+ t0 = time.time()
214
+ score_dump = []
215
+
216
+ import matplotlib.pyplot as plt
217
+
218
+ for exi, ex in enumerate(examples):
219
+ instruction = ex[eval_func_param_names.index('instruction_nochat')]
220
+ iinput = ex[eval_func_param_names.index('iinput_nochat')]
221
+ context = ex[eval_func_param_names.index('context')]
222
+ clear_torch_cache()
223
+ print("")
224
+ print("START" + "=" * 100)
225
+ print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
226
+ print("-" * 105)
227
+ # fun yields as generator, so have to iterate over it
228
+ # Also means likely do NOT want --stream_output=True, else would show all generations
229
+ for res in fun(*tuple(ex), exi=exi):
230
+ print(res)
231
+ if smodel:
232
+ score_with_prompt = False
233
+ if score_with_prompt:
234
+ data_point = dict(instruction=instruction, input=iinput, context=context)
235
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
236
+ prompt = prompter.generate_prompt(data_point)
237
+ else:
238
+ # just raw input and output
239
+ assert iinput in [None, ''] # should be no iinput
240
+ assert context in [None, ''] # should be no context
241
+ prompt = instruction
242
+ cutoff_len = 768 if is_low_mem else 2048
243
+ inputs = stokenizer(prompt, res,
244
+ return_tensors="pt",
245
+ truncation=True,
246
+ max_length=cutoff_len)
247
+ try:
248
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
249
+ except torch.cuda.OutOfMemoryError as e:
250
+ print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
251
+ traceback.print_exc()
252
+ score = 0.0
253
+ clear_torch_cache()
254
+ except (Exception, RuntimeError) as e:
255
+ if 'Expected all tensors to be on the same device' in str(e) or \
256
+ 'expected scalar type Half but found Float' in str(e) or \
257
+ 'probability tensor contains either' in str(e) or \
258
+ 'cublasLt ran into an error!' in str(e):
259
+ print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
260
+ flush=True)
261
+ traceback.print_exc()
262
+ score = 0.0
263
+ clear_torch_cache()
264
+ else:
265
+ raise
266
+ print("SCORE %s: %s" % (exi, score), flush=True)
267
+ score_dump.append(ex + [prompt, res, score])
268
+ # dump every score in case abort
269
+ df_scores = pd.DataFrame(score_dump,
270
+ columns=eval_func_param_names + eval_extra_columns)
271
+ df_scores.to_parquet(eval_filename, index=False)
272
+ # plot histogram so far
273
+ plt.figure(figsize=(10, 10))
274
+ plt.hist(df_scores['score'], bins=20)
275
+ score_avg = np.mean(df_scores['score'])
276
+ score_median = np.median(df_scores['score'])
277
+ plt.title("Score avg: %s median: %s" % (score_avg, score_median))
278
+ plt.savefig(eval_filename.replace('.parquet', '.png'))
279
+ plt.close()
280
+
281
+ print("END" + "=" * 102)
282
+ print("")
283
+ t2 = time.time()
284
+ print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
285
+ t1 = time.time()
286
+ print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
287
+ return eval_filename
288
+
289
+ if gradio:
290
+ go_gradio(**locals())
291
+
292
+
293
+ def get_device():
294
+ if torch.cuda.is_available():
295
+ device = "cuda"
296
+ else:
297
+ raise RuntimeError("only cuda supported")
298
+
299
+ return device
300
+
301
+
302
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
303
+ gpu_id=0,
304
+ use_auth_token=False):
305
+ """
306
+ Ensure model gets on correct device
307
+ :param base_model:
308
+ :param model_loader:
309
+ :param load_half:
310
+ :param model_kwargs:
311
+ :param reward_type:
312
+ :param gpu_id:
313
+ :param use_auth_token:
314
+ :return:
315
+ """
316
+ with init_empty_weights():
317
+ from transformers import AutoConfig
318
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
319
+ model = AutoModel.from_config(
320
+ config,
321
+ )
322
+
323
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
324
+ # NOTE: Some models require avoiding sharding some layers,
325
+ # then would pass no_split_module_classes and give list of those layers.
326
+ device_map = infer_auto_device_map(
327
+ model,
328
+ dtype=torch.float16 if load_half else torch.float32,
329
+ )
330
+ if hasattr(model, 'model'):
331
+ device_map_model = infer_auto_device_map(
332
+ model.model,
333
+ dtype=torch.float16 if load_half else torch.float32,
334
+ )
335
+ device_map.update(device_map_model)
336
+ print('device_map: %s' % device_map, flush=True)
337
+
338
+ if gpu_id >= 0:
339
+ # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
340
+ # So avoid for now, just put on first GPU, unless score_model, put on last
341
+ n_gpus = torch.cuda.device_count()
342
+ if reward_type:
343
+ device_map = {'': n_gpus - 1}
344
+ else:
345
+ device_map = {'': min(n_gpus - 1, gpu_id)}
346
+
347
+ load_in_8bit = model_kwargs.get('load_in_8bit', False)
348
+ model_kwargs['device_map'] = device_map
349
+
350
+ if load_in_8bit or not load_half:
351
+ model = model_loader.from_pretrained(
352
+ base_model,
353
+ **model_kwargs,
354
+ )
355
+ else:
356
+ model = model_loader.from_pretrained(
357
+ base_model,
358
+ **model_kwargs,
359
+ ).half()
360
+ return model
361
+
362
+
363
+ def get_model(
364
+ load_8bit: bool = False,
365
+ load_half: bool = True,
366
+ infer_devices: bool = True,
367
+ base_model: str = '',
368
+ tokenizer_base_model: str = '',
369
+ lora_weights: str = "",
370
+ gpu_id: int = 0,
371
+
372
+ llama_type: bool = None,
373
+ reward_type: bool = None,
374
+ local_files_only: bool = False,
375
+ resume_download: bool = True,
376
+ use_auth_token: Union[str, bool] = False,
377
+ compile: bool = True,
378
+ **kwargs,
379
+ ):
380
+ """
381
+
382
+ :param load_8bit: load model in 8-bit, not supported by all models
383
+ :param load_half: load model in 16-bit
384
+ :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
385
+ For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
386
+ So it is not the default
387
+ :param base_model: name/path of base model
388
+ :param tokenizer_base_model: name/path of tokenizer
389
+ :param lora_weights: name/path
390
+ :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
391
+ :param llama_type: whether LLaMa type model
392
+ :param reward_type: reward type model for sequence classification
393
+ :param local_files_only: use local files instead of from HF
394
+ :param resume_download: resume downloads from HF
395
+ :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
396
+ :param compile: whether to compile torch model
397
+ :param kwargs:
398
+ :return:
399
+ """
400
+ print("Get %s model" % base_model, flush=True)
401
+ if lora_weights is not None and lora_weights.strip():
402
+ print("Get %s lora weights" % lora_weights, flush=True)
403
+ device = get_device()
404
+
405
+ if 'gpt2' in base_model.lower():
406
+ # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
407
+ load_8bit = False
408
+
409
+ assert base_model.strip(), (
410
+ "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
411
+ )
412
+ llama_type = llama_type or "llama" in base_model
413
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
414
+ if not tokenizer_base_model:
415
+ tokenizer_base_model = base_model
416
+
417
+ if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
418
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
419
+ local_files_only=local_files_only,
420
+ resume_download=resume_download,
421
+ use_auth_token=use_auth_token,
422
+ )
423
+ else:
424
+ tokenizer = tokenizer_loader
425
+
426
+ if isinstance(tokenizer, str):
427
+ # already a pipeline, tokenizer_loader is string for task
428
+ model = model_loader(tokenizer,
429
+ model=base_model,
430
+ device=0 if device == "cuda" else -1,
431
+ torch_dtype=torch.float16)
432
+ else:
433
+ assert device == "cuda", "Unsupported device %s" % device
434
+ model_kwargs = dict(local_files_only=local_files_only,
435
+ torch_dtype=torch.float16,
436
+ resume_download=resume_download,
437
+ use_auth_token=use_auth_token)
438
+ if 'mbart-' not in base_model.lower():
439
+ model_kwargs.update(dict(load_in_8bit=load_8bit,
440
+ device_map={"": 0} if load_8bit else "auto",
441
+ ))
442
+ if 'OpenAssistant/reward-model'.lower() in base_model.lower():
443
+ # could put on other GPUs
444
+ model_kwargs['device_map'] = {"": 0}
445
+ model_kwargs.pop('torch_dtype', None)
446
+
447
+ if not lora_weights:
448
+ with torch.device("cuda"):
449
+ if infer_devices:
450
+ model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
451
+ gpu_id=gpu_id, use_auth_token=use_auth_token)
452
+ else:
453
+ if load_half and not load_8bit:
454
+ model = model_loader.from_pretrained(
455
+ base_model,
456
+ **model_kwargs).half()
457
+ else:
458
+ model = model_loader.from_pretrained(
459
+ base_model,
460
+ **model_kwargs)
461
+ elif load_8bit:
462
+ model = model_loader.from_pretrained(
463
+ base_model,
464
+ **model_kwargs
465
+ )
466
+ model = PeftModel.from_pretrained(
467
+ model,
468
+ lora_weights,
469
+ torch_dtype=torch.float16,
470
+ local_files_only=local_files_only,
471
+ resume_download=resume_download,
472
+ use_auth_token=use_auth_token,
473
+ device_map={"": 0}, # seems to be required
474
+ )
475
+ else:
476
+ with torch.device("cuda"):
477
+ model = model_loader.from_pretrained(
478
+ base_model,
479
+ **model_kwargs
480
+ )
481
+ model = PeftModel.from_pretrained(
482
+ model,
483
+ lora_weights,
484
+ torch_dtype=torch.float16,
485
+ local_files_only=local_files_only,
486
+ resume_download=resume_download,
487
+ use_auth_token=use_auth_token,
488
+ device_map="auto",
489
+ )
490
+ if load_half:
491
+ model.half()
492
+
493
+ # unwind broken decapoda-research config
494
+ if llama_type:
495
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
496
+ model.config.bos_token_id = 1
497
+ model.config.eos_token_id = 2
498
+ if 'gpt2' in base_model.lower():
499
+ # add special tokens that otherwise all share the same id
500
+ tokenizer.add_special_tokens({'bos_token': '<bos>',
501
+ 'eos_token': '<eos>',
502
+ 'pad_token': '<pad>'})
503
+
504
+ if not isinstance(tokenizer, str):
505
+ model.eval()
506
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile:
507
+ model = torch.compile(model)
508
+
509
+ return model, tokenizer, device
510
+
511
+
512
+ def get_score_model(**kwargs):
513
+ # score model
514
+ if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
515
+ score_all_kwargs = kwargs.copy()
516
+ score_all_kwargs['load_8bit'] = False
517
+ score_all_kwargs['load_half'] = False
518
+ score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
519
+ score_all_kwargs['tokenizer_base_model'] = ''
520
+ score_all_kwargs['lora_weights'] = ''
521
+ score_all_kwargs['llama_type'] = False
522
+ score_all_kwargs['compile'] = False
523
+ smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
524
+ else:
525
+ smodel, stokenizer, sdevice = None, None, None
526
+ return smodel, stokenizer, sdevice
527
+
528
+
529
+ def go_gradio(**kwargs):
530
+ # get default model
531
+ all_kwargs = kwargs.copy()
532
+ all_kwargs.update(locals())
533
+ if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
534
+ model0, tokenizer0, device = get_model(**all_kwargs)
535
+ else:
536
+ # if empty model, then don't load anything, just get gradio up
537
+ model0, tokenizer0, device = None, None, None
538
+ model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
539
+
540
+ # get score model
541
+ smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
542
+
543
+ if 'mbart-' in kwargs['model_lower']:
544
+ instruction_label_nochat = "Text to translate"
545
+ else:
546
+ instruction_label_nochat = "Instruction"
547
+ instruction_label = "You (Shift-Enter or push Submit to send message)"
548
+
549
+ title = 'h2oGPT'
550
+ if kwargs['verbose']:
551
+ description = f"""Model {kwargs['base_model']} Instruct dataset.
552
+ For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
553
+ Command: {str(' '.join(sys.argv))}
554
+ Hash: {get_githash()}
555
+ """
556
+ else:
557
+ description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
558
+ if is_public:
559
+ description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
560
+ if kwargs['load_8bit']:
561
+ description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
562
+ description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
563
+ description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
564
+
565
+ if kwargs['verbose']:
566
+ task_info_md = f"""
567
+ ### Task: {kwargs['task_info']}"""
568
+ else:
569
+ task_info_md = ''
570
+
571
+ css_code = """footer {visibility: hidden;}
572
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
573
+ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
574
+
575
+ from gradio.themes.utils import Color, colors, fonts, sizes
576
+ if kwargs['h2ocolors']:
577
+ h2o_yellow = Color(
578
+ name="yellow",
579
+ c50="#fffef2",
580
+ c100="#fff9e6",
581
+ c200="#ffecb3",
582
+ c300="#ffe28c",
583
+ c400="#ffd659",
584
+ c500="#fec925",
585
+ c600="#e6ac00",
586
+ c700="#bf8f00",
587
+ c800="#a67c00",
588
+ c900="#664d00",
589
+ c950="#403000",
590
+ )
591
+ h2o_gray = Color(
592
+ name="gray",
593
+ c50="#f2f2f2",
594
+ c100="#e5e5e5",
595
+ c200="#cccccc",
596
+ c300="#b2b2b2",
597
+ c400="#999999",
598
+ c500="#7f7f7f",
599
+ c600="#666666",
600
+ c700="#4c4c4c",
601
+ c800="#333333",
602
+ c900="#191919",
603
+ c950="#0d0d0d",
604
+ )
605
+ colors_dict = dict(primary_hue=h2o_yellow,
606
+ secondary_hue=h2o_yellow,
607
+ neutral_hue=h2o_gray,
608
+ spacing_size=sizes.spacing_md,
609
+ radius_size=sizes.radius_md,
610
+ text_size=sizes.text_md,
611
+ )
612
+ else:
613
+ colors_dict = dict(primary_hue=colors.indigo,
614
+ secondary_hue=colors.indigo,
615
+ neutral_hue=colors.gray,
616
+ spacing_size=sizes.spacing_md,
617
+ radius_size=sizes.radius_md,
618
+ text_size=sizes.text_md,
619
+ )
620
+
621
+ import gradio as gr
622
+
623
+ if kwargs['gradio_avoid_processing_markdown']:
624
+ from gradio_client import utils as client_utils
625
+ from gradio.components import Chatbot
626
+
627
+ # gradio has issue with taking too long to process input/output for markdown etc.
628
+ # Avoid for now, allow raw html to render, good enough for chatbot.
629
+ def _postprocess_chat_messages(self, chat_message: str):
630
+ if chat_message is None:
631
+ return None
632
+ elif isinstance(chat_message, (tuple, list)):
633
+ filepath = chat_message[0]
634
+ mime_type = client_utils.get_mimetype(filepath)
635
+ filepath = self.make_temp_copy_if_needed(filepath)
636
+ return {
637
+ "name": filepath,
638
+ "mime_type": mime_type,
639
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
640
+ "data": None, # These last two fields are filled in by the frontend
641
+ "is_file": True,
642
+ }
643
+ elif isinstance(chat_message, str):
644
+ return chat_message
645
+ else:
646
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
647
+
648
+ Chatbot._postprocess_chat_messages = _postprocess_chat_messages
649
+
650
+ demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
651
+ callback = gr.CSVLogger()
652
+ # css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
653
+ # demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
654
+
655
+ model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
656
+ if kwargs['base_model'].strip() not in model_options:
657
+ lora_options = [kwargs['base_model'].strip()] + model_options
658
+ lora_options = kwargs['extra_lora_options']
659
+ if kwargs['lora_weights'].strip() not in lora_options:
660
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
661
+ # always add in no lora case
662
+ # add fake space so doesn't go away in gradio dropdown
663
+ no_lora_str = no_model_str = '[None/Remove]'
664
+ lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
665
+ # always add in no model case so can free memory
666
+ # add fake space so doesn't go away in gradio dropdown
667
+ model_options = [no_model_str] + model_options
668
+
669
+ # transcribe, will be detranscribed before use by evaluate()
670
+ if not kwargs['lora_weights'].strip():
671
+ kwargs['lora_weights'] = no_lora_str
672
+
673
+ if not kwargs['base_model'].strip():
674
+ kwargs['base_model'] = no_model_str
675
+
676
+ # transcribe for gradio
677
+ kwargs['gpu_id'] = str(kwargs['gpu_id'])
678
+
679
+ no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
680
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
681
+ 'base_model') else no_model_msg
682
+ output_label0_model2 = no_model_msg
683
+
684
+ with demo:
685
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
686
+ # https://github.com/gradio-app/gradio/issues/3558
687
+ model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
688
+ model_state2 = gr.State([None, None, None, None])
689
+ model_options_state = gr.State([model_options])
690
+ lora_options_state = gr.State([lora_options])
691
+ gr.Markdown(
692
+ f"""
693
+ <h1 align="center"> {title}</h1>
694
+
695
+ {description}
696
+ {task_info_md}
697
+ """)
698
+ if is_hf:
699
+ gr.HTML(
700
+ '''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
701
+
702
+ # go button visible if
703
+ base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
704
+ go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
705
+ normal_block = gr.Row(visible=not base_wanted)
706
+ with normal_block:
707
+ with gr.Tabs():
708
+ with gr.Row():
709
+ col_nochat = gr.Column(visible=not kwargs['chat'])
710
+ with col_nochat: # FIXME: for model comparison, and check rest
711
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0)
712
+ instruction_nochat = gr.Textbox(
713
+ lines=4, label=instruction_label_nochat,
714
+ placeholder=kwargs['placeholder_instruction'],
715
+ )
716
+ iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
717
+ placeholder=kwargs['placeholder_input'])
718
+ submit_nochat = gr.Button("Submit")
719
+ flag_btn_nochat = gr.Button("Flag")
720
+ if kwargs['score_model']:
721
+ if not kwargs['auto_score']:
722
+ with gr.Column():
723
+ score_btn_nochat = gr.Button("Score last prompt & response")
724
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
725
+ else:
726
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
727
+ col_chat = gr.Column(visible=kwargs['chat'])
728
+ with col_chat:
729
+ with gr.Row():
730
+ text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
731
+ text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
732
+ height=kwargs['height'] or 400)
733
+ with gr.Row():
734
+ with gr.Column(scale=50):
735
+ instruction = gr.Textbox(
736
+ lines=4, label=instruction_label,
737
+ placeholder=kwargs['placeholder_instruction'],
738
+ )
739
+ with gr.Row(): # .style(equal_height=False, equal_width=False):
740
+ submit = gr.Button(value='Submit').style(full_width=False, size='sm')
741
+ stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
742
+ with gr.Row():
743
+ clear = gr.Button("New Conversation")
744
+ flag_btn = gr.Button("Flag")
745
+ if kwargs['score_model']:
746
+ if not kwargs['auto_score']: # FIXME: For checkbox model2
747
+ with gr.Column():
748
+ with gr.Row():
749
+ score_btn = gr.Button("Score last prompt & response").style(
750
+ full_width=False, size='sm')
751
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
752
+ score_res2 = gr.Row(visible=False)
753
+ with score_res2:
754
+ score_btn2 = gr.Button("Score last prompt & response 2").style(
755
+ full_width=False, size='sm')
756
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
757
+ else:
758
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
759
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
760
+ retry = gr.Button("Regenerate")
761
+ undo = gr.Button("Undo")
762
+ with gr.TabItem("Input/Output"):
763
+ with gr.Row():
764
+ if 'mbart-' in kwargs['model_lower']:
765
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
766
+ value=kwargs['src_lang'],
767
+ label="Input Language")
768
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
769
+ value=kwargs['tgt_lang'],
770
+ label="Output Language")
771
+ with gr.TabItem("Expert"):
772
+ with gr.Row():
773
+ with gr.Column():
774
+ stream_output = gr.components.Checkbox(label="Stream output",
775
+ value=kwargs['stream_output'])
776
+ prompt_type = gr.Dropdown(prompt_types_strings,
777
+ value=kwargs['prompt_type'], label="Prompt Type",
778
+ visible=not is_public)
779
+ prompt_type2 = gr.Dropdown(prompt_types_strings,
780
+ value=kwargs['prompt_type'], label="Prompt Type Model 2",
781
+ visible=not is_public and False)
782
+ do_sample = gr.Checkbox(label="Sample", info="Enable sampler, required for use of temperature, top_p, top_k",
783
+ value=kwargs['do_sample'])
784
+ temperature = gr.Slider(minimum=0.01, maximum=3,
785
+ value=kwargs['temperature'],
786
+ label="Temperature",
787
+ info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
788
+ top_p = gr.Slider(minimum=0, maximum=1,
789
+ value=kwargs['top_p'], label="Top p",
790
+ info="Cumulative probability of tokens to sample from")
791
+ top_k = gr.Slider(
792
+ minimum=0, maximum=100, step=1,
793
+ value=kwargs['top_k'], label="Top k",
794
+ info='Num. tokens to sample from'
795
+ )
796
+ max_beams = 8 if not is_low_mem else 2
797
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
798
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
799
+ info="Number of searches for optimal overall probability. "
800
+ "Uses more GPU memory/compute")
801
+ max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
802
+ max_new_tokens = gr.Slider(
803
+ minimum=1, maximum=max_max_new_tokens, step=1,
804
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
805
+ )
806
+ min_new_tokens = gr.Slider(
807
+ minimum=0, maximum=max_max_new_tokens, step=1,
808
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
809
+ )
810
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
811
+ value=kwargs['early_stopping'])
812
+ max_max_time = 60 * 5 if not is_low_mem else 60
813
+ max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
814
+ value=min(max_max_time, kwargs['max_time']), label="Max. time",
815
+ info="Max. time to search optimal output.")
816
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
817
+ value=kwargs['repetition_penalty'],
818
+ label="Repetition Penalty")
819
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
820
+ value=kwargs['num_return_sequences'],
821
+ label="Number Returns", info="Must be <= num_beams",
822
+ visible=not is_public)
823
+ iinput = gr.Textbox(lines=4, label="Input",
824
+ placeholder=kwargs['placeholder_input'],
825
+ visible=not is_public)
826
+ context = gr.Textbox(lines=3, label="System Pre-Context",
827
+ info="Directly pre-appended without prompt processing",
828
+ visible=not is_public and not kwargs['chat'])
829
+ chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
830
+ visible=not is_public)
831
+
832
+ with gr.TabItem("Models"):
833
+ load_msg = "Load-Unload Model/LORA" if not is_public \
834
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
835
+ load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
836
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
837
+ compare_checkbox = gr.components.Checkbox(label="Compare Mode",
838
+ value=False, visible=not is_public)
839
+ with gr.Row():
840
+ n_gpus = torch.cuda.device_count()
841
+ n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
842
+ with gr.Column():
843
+ with gr.Row(scale=1):
844
+ with gr.Column(scale=50):
845
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
846
+ value=kwargs['base_model'])
847
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
848
+ value=kwargs['lora_weights'], visible=kwargs['show_lora'])
849
+ with gr.Column(scale=1):
850
+ load_model_button = gr.Button(load_msg)
851
+ model_load8bit_checkbox = gr.components.Checkbox(
852
+ label="Load 8-bit [Not all models support]",
853
+ value=kwargs['load_8bit'])
854
+ model_infer_devices_checkbox = gr.components.Checkbox(
855
+ label="Infer Devices [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
856
+ value=kwargs['infer_devices'])
857
+ model_gpu = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
858
+ value=kwargs['gpu_id'])
859
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
860
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
861
+ visible=kwargs['show_lora'])
862
+ with gr.Row(scale=1):
863
+ with gr.Column(scale=50):
864
+ new_model = gr.Textbox(label="New Model HF name/path")
865
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
866
+ with gr.Column(scale=1):
867
+ add_model_button = gr.Button("Add new model name")
868
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
869
+ col_model2 = gr.Column(visible=False)
870
+ with col_model2:
871
+ with gr.Row(scale=1):
872
+ with gr.Column(scale=50):
873
+ model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
874
+ value=no_model_str)
875
+ lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
876
+ value=no_lora_str,
877
+ visible=kwargs['show_lora'])
878
+ with gr.Column(scale=1):
879
+ load_model_button2 = gr.Button(load_msg2)
880
+ model_load8bit_checkbox2 = gr.components.Checkbox(
881
+ label="Load 8-bit 2 [Not all models support]",
882
+ value=kwargs['load_8bit'])
883
+ model_infer_devices_checkbox2 = gr.components.Checkbox(
884
+ label="Infer Devices 2 [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
885
+ value=kwargs[
886
+ 'infer_devices'])
887
+ model_gpu2 = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
888
+ value=kwargs['gpu_id'])
889
+ # no model/lora loaded ever in model2 by default
890
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
891
+ lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
892
+ visible=kwargs['show_lora'])
893
+ with gr.TabItem("System"):
894
+ system_row = gr.Row(visible=not is_public)
895
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
896
+ admin_btn = gr.Button(value="admin", visible=is_public)
897
+ with system_row:
898
+ with gr.Column():
899
+ system_text = gr.Textbox(label='System Info')
900
+ system_btn = gr.Button(value='Get System Info')
901
+
902
+ zip_btn = gr.Button("Zip")
903
+ file_output = gr.File()
904
+
905
+ # Get flagged data
906
+ zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
907
+ zip_btn.click(zip_data1, inputs=None, outputs=file_output)
908
+
909
+ def check_admin_pass(x):
910
+ return gr.update(visible=x == admin_pass)
911
+
912
+ admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)
913
+
914
+ # Get inputs to evaluate()
915
+ inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
916
+ from functools import partial
917
+ all_kwargs = kwargs.copy()
918
+ all_kwargs.update(locals())
919
+ kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
920
+ fun = partial(evaluate,
921
+ **kwargs_evaluate)
922
+ fun2 = partial(evaluate,
923
+ model_state2,
924
+ **kwargs_evaluate)
925
+
926
+ dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
927
+ size="sm",
928
+ )
929
+ dark_mode_btn.click(
930
+ None,
931
+ None,
932
+ None,
933
+ _js="""() => {
934
+ if (document.querySelectorAll('.dark').length) {
935
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
936
+ } else {
937
+ document.querySelector('body').classList.add('dark');
938
+ }
939
+ }""",
940
+ api_name="dark",
941
+ )
942
+
943
+ # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
944
+ def col_nochat_fun(x):
945
+ return gr.Column.update(visible=not x)
946
+
947
+ def col_chat_fun(x):
948
+ return gr.Column.update(visible=x)
949
+
950
+ def context_fun(x):
951
+ return gr.Textbox.update(visible=not x)
952
+
953
+ chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox") \
954
+ .then(col_chat_fun, chat, col_chat) \
955
+ .then(context_fun, chat, context)
956
+
957
+ # examples after submit or any other buttons for chat or no chat
958
+ if kwargs['examples'] is not None and kwargs['show_examples']:
959
+ gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
960
+
961
+ # Score
962
+ def score_last_response(*args, nochat=False, model2=False):
963
+ """ Similar to user() """
964
+ args_list = list(args)
965
+
966
+ max_length_tokenize = 512 if is_low_mem else 2048
967
+ cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
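+ # cutoff_len is in characters; the factor of 4 presumably assumes ~4 characters per token on average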
968
+
969
+ if not nochat:
970
+ history = args_list[-1]
971
+ if history is None:
972
+ if not model2:
973
+ # maybe only doing first model, no need to complain
974
+ print("Bad history in scoring last response, fix for now", flush=True)
975
+ history = []
976
+ if smodel is not None and \
977
+ stokenizer is not None and \
978
+ sdevice is not None and \
979
+ history is not None and len(history) > 0 and \
980
+ history[-1] is not None and \
981
+ len(history[-1]) >= 2:
982
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
983
+
984
+ question = history[-1][0]
985
+
986
+ answer = history[-1][1]
987
+ else:
988
+ return 'Response Score: NA'
989
+ else:
990
+ answer = args_list[-1]
991
+ instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
992
+ question = args_list[instruction_nochat_arg_id]
993
+
994
+ if question is None:
995
+ return 'Response Score: Bad Question'
996
+ if answer is None:
997
+ return 'Response Score: Bad Answer'
998
+
999
+ question = question[-cutoff_len:]
1000
+ answer = answer[-cutoff_len:]
1001
+
1002
+ inputs = stokenizer(question, answer,
1003
+ return_tensors="pt",
1004
+ truncation=True,
1005
+ max_length=max_length_tokenize).to(smodel.device)
1006
+ try:
1007
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
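+ # the reward model returns a single logit; sigmoid maps it to a 0-1 score for the (question, answer) pair, formatted as a percentage below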
1008
+ except torch.cuda.OutOfMemoryError as e:
1009
+ print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
1010
+ del inputs
1011
+ traceback.print_exc()
1012
+ clear_torch_cache()
1013
+ return 'Response Score: GPU OOM'
1014
+ except (Exception, RuntimeError) as e:
1015
+ if 'Expected all tensors to be on the same device' in str(e) or \
1016
+ 'expected scalar type Half but found Float' in str(e) or \
1017
+ 'probability tensor contains either' in str(e) or \
1018
+ 'cublasLt ran into an error!' in str(e):
1019
+ print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
1020
+ flush=True)
1021
+ traceback.print_exc()
1022
+ clear_torch_cache()
1023
+ return 'Response Score: GPU Error'
1024
+ else:
1025
+ raise
1026
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
1027
+ return 'Response Score: {:.1%}'.format(score)
1028
+
1029
+ if kwargs['score_model']:
1030
+ score_args = dict(fn=score_last_response,
1031
+ inputs=inputs_list + [text_output],
1032
+ outputs=[score_text],
1033
+ )
1034
+ score_args2 = dict(fn=partial(score_last_response, model2=True),
1035
+ inputs=inputs_list + [text_output2],
1036
+ outputs=[score_text2],
1037
+ )
1038
+
1039
+ score_args_nochat = dict(fn=partial(score_last_response, nochat=True),
1040
+ inputs=inputs_list + [text_output_nochat],
1041
+ outputs=[score_text_nochat],
1042
+ )
1043
+ if not kwargs['auto_score']:
1044
+ score_event = score_btn.click(**score_args, queue=stream_output, api_name='score') \
1045
+ .then(**score_args2, queue=stream_output, api_name='score2')
1046
+ score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=stream_output,
1047
+ api_name='score_nochat')
1048
+
1049
+ def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
1050
+ """
1051
+ User that fills history for bot
1052
+ :param args:
1053
+ :param undo:
1054
+ :param sanitize_user_prompt:
1055
+ :param model2:
1056
+ :return:
1057
+ """
1058
+ args_list = list(args)
1059
+ user_message = args_list[0]
1060
+ input1 = args_list[1]
1061
+ context1 = args_list[2]
1062
+ if input1 and not user_message.endswith(':'):
1063
+ user_message1 = user_message + ":" + input1
1064
+ elif input1:
1065
+ user_message1 = user_message + input1
1066
+ else:
1067
+ user_message1 = user_message
1068
+ if sanitize_user_prompt:
1069
+ from better_profanity import profanity
1070
+ user_message1 = profanity.censor(user_message1)
1071
+
1072
+ history = args_list[-1]
1073
+ if undo and history:
1074
+ history.pop()
1075
+ args_list = args_list[:-1] # FYI, even if unused currently
1076
+ if history is None:
1077
+ if not model2:
1078
+ # no need to complain so often unless model1
1079
+ print("Bad history, fix for now", flush=True)
1080
+ history = []
1081
+ # ensure elements not mixed across models as output,
1082
+ # even if input is currently same source
1083
+ history = history.copy()
1084
+ if undo:
1085
+ return history
1086
+ else:
1087
+ # FIXME: compare, same history for now
1088
+ return history + [[user_message1, None]]
1089
+
1090
+ def bot(*args, retry=False):
1091
+ """
1092
+ bot that consumes the history filled in by user() for the user input;
1093
+ the instruction (from inputs_list) itself is not consumed by bot
1094
+ :param args:
1095
+ :param retry:
1096
+ :return:
1097
+ """
1098
+ args_list = list(args).copy()
1099
+ history = args_list[-1] # model_state is -2
1100
+ if retry and history:
1101
+ history.pop()
1102
+ if not history:
1103
+ print("No history", flush=True)
1104
+ return
1105
+ # ensure output will be unique to models
1106
+ history = history.copy()
1107
+ instruction1 = history[-1][0]
1108
+ context1 = ''
1109
+ if kwargs['chat_history'] > 0:
1110
+ prompt_type_arg_id = eval_func_param_names.index('prompt_type')
1111
+ prompt_type1 = args_list[prompt_type_arg_id]
1112
+ chat_arg_id = eval_func_param_names.index('chat')
1113
+ chat1 = args_list[chat_arg_id]
1114
+ context1 = ''
1115
+ for histi in range(len(history) - 1):
1116
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
1117
+ context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
1118
+ '<br>', '\n')
1119
+ if not context1.endswith('\n'):
1120
+ context1 += '\n'
1121
+ if context1 and not context1.endswith('\n'):
1122
+ context1 += '\n' # ensure if terminates abruptly, then human continues on next line
1123
+ args_list[0] = instruction1 # override original instruction with history from user
1124
+ # only include desired chat history
1125
+ args_list[2] = context1[-kwargs['chat_history']:]
1126
+ model_state1 = args_list[-2]
1127
+ if model_state1[0] is None or model_state1[0] == no_model_str:
1128
+ return
1129
+ args_list = args_list[:-2]
1130
+ fun1 = partial(evaluate,
1131
+ model_state1,
1132
+ **kwargs_evaluate)
1133
+ try:
1134
+ for output in fun1(*tuple(args_list)):
1135
+ bot_message = output
1136
+ history[-1][1] = bot_message
1137
+ yield history
1138
+ except StopIteration:
1139
+ yield history
1140
+ except RuntimeError as e:
1141
+ if "generator raised StopIteration" in str(e):
1142
+ # assume last entry was bad, undo
1143
+ history.pop()
1144
+ yield history
1145
+ raise
1146
+ except Exception as e:
1147
+ # put error into user input
1148
+ history[-1][0] = "Exception: %s" % str(e)
1149
+ yield history
1150
+ raise
1151
+ return
1152
+
1153
+ # NORMAL MODEL
1154
+ user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
1155
+ inputs=inputs_list + [text_output],
1156
+ outputs=text_output,
1157
+ )
1158
+ bot_args = dict(fn=bot,
1159
+ inputs=inputs_list + [model_state] + [text_output],
1160
+ outputs=text_output,
1161
+ )
1162
+ retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1163
+ inputs=inputs_list + [model_state] + [text_output],
1164
+ outputs=text_output,
1165
+ )
1166
+ undo_user_args = dict(fn=functools.partial(user, undo=True),
1167
+ inputs=inputs_list + [text_output],
1168
+ outputs=text_output,
1169
+ )
1170
+
1171
+ # MODEL2
1172
+ user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
1173
+ inputs=inputs_list + [text_output2],
1174
+ outputs=text_output2,
1175
+ )
1176
+ bot_args2 = dict(fn=bot,
1177
+ inputs=inputs_list + [model_state2] + [text_output2],
1178
+ outputs=text_output2,
1179
+ )
1180
+ retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1181
+ inputs=inputs_list + [model_state2] + [text_output2],
1182
+ outputs=text_output2,
1183
+ )
1184
+ undo_user_args2 = dict(fn=functools.partial(user, undo=True),
1185
+ inputs=inputs_list + [text_output2],
1186
+ outputs=text_output2,
1187
+ )
1188
+
1189
+ def clear_instruct():
1190
+ return gr.Textbox.update(value='')
1191
+
1192
+ if kwargs['auto_score']:
1193
+ # in case 2nd model, consume instruction first, so can clear quickly
1194
+ # bot doesn't consume the instruction itself, just the history from user(), which is why this works
1195
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
1196
+ .then(**user_args2, queue=stream_output, api_name='instruction2') \
1197
+ .then(clear_instruct, None, instruction) \
1198
+ .then(**bot_args, api_name='instruction_bot') \
1199
+ .then(**score_args, api_name='instruction_bot_score') \
1200
+ .then(**bot_args2, api_name='instruction_bot2') \
1201
+ .then(**score_args2, api_name='instruction_bot_score2') \
1202
+ .then(clear_torch_cache)
1203
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
1204
+ .then(**user_args2, queue=stream_output, api_name='submit2') \
1205
+ .then(**bot_args, api_name='submit_bot') \
1206
+ .then(clear_instruct, None, instruction) \
1207
+ .then(**score_args, api_name='submit_bot_score') \
1208
+ .then(**bot_args2, api_name='submit_bot2') \
1209
+ .then(**score_args2, api_name='submit_bot_score2') \
1210
+ .then(clear_torch_cache)
1211
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
1212
+ .then(**user_args2, queue=stream_output, api_name='retry2') \
1213
+ .then(clear_instruct, None, instruction) \
1214
+ .then(**retry_bot_args, api_name='retry_bot') \
1215
+ .then(**score_args, api_name='retry_bot_score') \
1216
+ .then(**retry_bot_args2, api_name='retry_bot2') \
1217
+ .then(**score_args2, api_name='retry_bot_score2') \
1218
+ .then(clear_torch_cache)
1219
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
1220
+ .then(**score_args, api_name='undo_score') \
1221
+ .then(**undo_user_args2, queue=stream_output, api_name='undo2') \
1222
+ .then(**score_args2, api_name='undo_score2') \
1223
+ .then(clear_instruct, None, instruction)
1224
+ else:
1225
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
1226
+ .then(**user_args2, queue=stream_output, api_name='instruction2') \
1227
+ .then(clear_instruct, None, instruction) \
1228
+ .then(**bot_args, api_name='instruction_bot') \
1229
+ .then(**bot_args2, api_name='instruction_bot2') \
1230
+ .then(clear_torch_cache)
1231
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
1232
+ .then(**user_args2, queue=stream_output, api_name='submit2') \
1233
+ .then(clear_instruct, None, instruction) \
1234
+ .then(**bot_args, api_name='submit_bot') \
1235
+ .then(**bot_args2, api_name='submit_bot2') \
1236
+ .then(clear_torch_cache)
1237
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
1238
+ .then(**user_args2, queue=stream_output, api_name='retry2') \
1239
+ .then(clear_instruct, None, instruction) \
1240
+ .then(**retry_bot_args, api_name='retry_bot') \
1241
+ .then(**retry_bot_args2, api_name='retry_bot2') \
1242
+ .then(clear_torch_cache)
1243
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
1244
+ .then(**undo_user_args2, queue=stream_output, api_name='undo2')
1245
+
1246
+ # does both models
1247
+ clear.click(lambda: None, None, text_output, queue=False, api_name='clear') \
1248
+ .then(lambda: None, None, text_output2, queue=False, api_name='clear2')
1249
+ # FIXME: compare
1250
+ submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
1251
+ outputs=text_output_nochat, api_name='submit_nochat') \
1252
+ .then(**score_args_nochat, api_name='instruction_bot_score_nochat') \
1253
+ .then(clear_torch_cache)
1254
+
1255
+ def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
1256
+ # ensure old model removed from GPU memory
1257
+ if kwargs['debug']:
1258
+ print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
1259
+
1260
+ if isinstance(model_state_old[0], str) and model0 is not None:
1261
+ # best can do, move model loaded at first to CPU
1262
+ model0.cpu()
1263
+
1264
+ if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
1265
+ try:
1266
+ model_state_old[0].cpu()
1267
+ except Exception as e:
1268
+ # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
1269
+ print("Unable to put model on CPU: %s" % str(e), flush=True)
1270
+ del model_state_old[0]
1271
+ model_state_old[0] = None
1272
+
1273
+ if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
1274
+ del model_state_old[1]
1275
+ model_state_old[1] = None
1276
+
1277
+ clear_torch_cache()
1278
+ if kwargs['debug']:
1279
+ print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
1280
+
1281
+ if model_name is None or model_name == no_model_str:
1282
+ # no-op if no model, just free memory
1283
+ # no detranscribe needed for model, never go into evaluate
1284
+ lora_weights = no_lora_str
1285
+ return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
1286
+
1287
+ all_kwargs1 = all_kwargs.copy()
1288
+ all_kwargs1['base_model'] = model_name.strip()
1289
+ all_kwargs1['load_8bit'] = load_8bit
1290
+ all_kwargs1['infer_devices'] = infer_devices
1291
+ all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
1292
+ model_lower = model_name.strip().lower()
1293
+ if model_lower in inv_prompt_type_to_model_lower:
1294
+ prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
1295
+ else:
1296
+ prompt_type1 = prompt_type_old
1297
+
1298
+ # detranscribe
1299
+ if lora_weights == no_lora_str:
1300
+ lora_weights = ''
1301
+
1302
+ all_kwargs1['lora_weights'] = lora_weights.strip()
1303
+ model1, tokenizer1, device1 = get_model(**all_kwargs1)
1304
+ clear_torch_cache()
1305
+
1306
+ if kwargs['debug']:
1307
+ print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
1308
+ return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
1309
+
1310
+ def dropdown_prompt_type_list(x):
1311
+ return gr.Dropdown.update(value=x)
1312
+
1313
+ def chatbot_list(x, model_used_in):
1314
+ return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
1315
+
1316
+ load_model_args = dict(fn=load_model,
1317
+ inputs=[model_choice, lora_choice, model_state, prompt_type,
1318
+ model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
1319
+ outputs=[model_state, model_used, lora_used, prompt_type])
1320
+ prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
1321
+ chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1322
+ nochat_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output_nochat)
1323
+ if not is_public:
1324
+ load_model_event = load_model_button.click(**load_model_args) \
1325
+ .then(**prompt_update_args) \
1326
+ .then(**chatbot_update_args) \
1327
+ .then(**nochat_update_args) \
1328
+ .then(clear_torch_cache)
1329
+
1330
+ load_model_args2 = dict(fn=load_model,
1331
+ inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
1332
+ model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
1333
+ outputs=[model_state2, model_used2, lora_used2, prompt_type2])
1334
+ prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
1335
+ chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
1336
+ if not is_public:
1337
+ load_model_event2 = load_model_button2.click(**load_model_args2) \
1338
+ .then(**prompt_update_args2) \
1339
+ .then(**chatbot_update_args2) \
1340
+ .then(clear_torch_cache)
1341
+
1342
+ def dropdown_model_list(list0, x):
1343
+ new_state = [list0[0] + [x]]
1344
+ new_options = [*new_state[0]]
1345
+ return gr.Dropdown.update(value=x, choices=new_options), \
1346
+ gr.Dropdown.update(value=x, choices=new_options), \
1347
+ '', new_state
1348
+
1349
+ add_model_event = add_model_button.click(fn=dropdown_model_list,
1350
+ inputs=[model_options_state, new_model],
1351
+ outputs=[model_choice, model_choice2, new_model, model_options_state])
1352
+
1353
+ def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
1354
+ new_state = [list0[0] + [x]]
1355
+ new_options = [*new_state[0]]
1356
+ # don't switch drop-down to added lora if already have model loaded
1357
+ x1 = x if model_used1 == no_model_str else lora_used1
1358
+ x2 = x if model_used2 == no_model_str else lora_used2
1359
+ return gr.Dropdown.update(value=x1, choices=new_options), \
1360
+ gr.Dropdown.update(value=x2, choices=new_options), \
1361
+ '', new_state
1362
+
1363
+ add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
1364
+ inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, lora_used2],
1365
+ outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
1366
+
1367
+ go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
1368
+ .then(lambda: gr.update(visible=True), None, normal_block) \
1369
+ .then(**load_model_args).then(**prompt_update_args)
1370
+
1371
+ def compare_textbox_fun(x):
1372
+ return gr.Textbox.update(visible=x)
1373
+
1374
+ def compare_column_fun(x):
1375
+ return gr.Column.update(visible=x)
1376
+
1377
+ def compare_prompt_fun(x):
1378
+ return gr.Dropdown.update(visible=x)
1379
+
1380
+ compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, api_name="compare_checkbox") \
1381
+ .then(compare_column_fun, compare_checkbox, col_model2) \
1382
+ .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
1383
+ .then(compare_textbox_fun, compare_checkbox, score_text2)
1384
+ # FIXME: add score_res2 in condition, but do better
1385
+
1386
+ # callback for logging flagged input/output
1387
+ callback.setup(inputs_list + [text_output], "flagged_data_points")
1388
+ flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
1389
+ api_name='flag')
1390
+ flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
1391
+ api_name='flag_nochat')
1392
+
1393
+ def get_system_info():
1394
+ return gr.Textbox.update(value=system_info_print())
1395
+
1396
+ system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')
1397
+
1398
+ # don't pass text_output, don't want to clear output, just stop it
1399
+ # FIXME: have to click once to stop output and second time to stop GPUs going
1400
+ stop_btn.click(lambda: None, None, None,
1401
+ cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
1402
+ queue=False, api_name='stop').then(clear_torch_cache)
1403
+
1404
+ demo.queue(concurrency_count=1)
1405
+ favicon_path = "h2o-logo.svg"
1406
+ demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1407
+ favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
1408
+ print("Started GUI", flush=True)
1409
+ demo.block_thread()
1410
+
1411
+
1412
+ input_args_list = ['model_state']
1413
+ inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
1414
+
1415
+
1416
+ def get_inputs_list(inputs_dict, model_lower):
1417
+ """
1418
+ map gradio objects in locals() to inputs for evaluate().
1419
+ :param inputs_dict:
1420
+ :param model_lower:
1421
+ :return:
1422
+ """
1423
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
1424
+ inputs_list = []
1425
+ for k in inputs_list_names:
1426
+ if k == 'kwargs':
1427
+ continue
1428
+ if k in input_args_list + inputs_kwargs_list:
1429
+ # these are added via partial, not taken as input
1430
+ continue
1431
+ if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
1432
+ continue
1433
+ inputs_list.append(inputs_dict[k])
1434
+ return inputs_list
1435
+
1436
+
1437
+ eval_func_param_names = ['instruction',
1438
+ 'iinput',
1439
+ 'context',
1440
+ 'stream_output',
1441
+ 'prompt_type',
1442
+ 'temperature',
1443
+ 'top_p',
1444
+ 'top_k',
1445
+ 'num_beams',
1446
+ 'max_new_tokens',
1447
+ 'min_new_tokens',
1448
+ 'early_stopping',
1449
+ 'max_time',
1450
+ 'repetition_penalty',
1451
+ 'num_return_sequences',
1452
+ 'do_sample',
1453
+ 'chat',
1454
+ 'instruction_nochat',
1455
+ 'iinput_nochat',
1456
+ ]
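+ # NOTE: order must match the parameters of evaluate() below (between its START/END NOTE markers) and the example rows built in get_generate_params()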
1457
+
1458
+
1459
+ def evaluate(
1460
+ model_state,
1461
+ # START NOTE: Examples must have same order of parameters
1462
+ instruction,
1463
+ iinput,
1464
+ context,
1465
+ stream_output,
1466
+ prompt_type,
1467
+ temperature,
1468
+ top_p,
1469
+ top_k,
1470
+ num_beams,
1471
+ max_new_tokens,
1472
+ min_new_tokens,
1473
+ early_stopping,
1474
+ max_time,
1475
+ repetition_penalty,
1476
+ num_return_sequences,
1477
+ do_sample,
1478
+ chat,
1479
+ instruction_nochat,
1480
+ iinput_nochat,
1481
+ # END NOTE: Examples must have same order of parameters
1482
+ src_lang=None,
1483
+ tgt_lang=None,
1484
+ debug=False,
1485
+ save_dir=None,
1486
+ hard_stop_list=None,
1487
+ sanitize_bot_response=True,
1488
+ model_state0=None,
1489
+ **kwargs,
1490
+ ):
1491
+ if debug:
1492
+ locals_dict = locals().copy()
1493
+ locals_dict.pop('model_state', None)
1494
+ locals_dict.pop('model_state0', None)
1495
+ print(locals_dict)
1496
+
1497
+ no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
1498
+
1499
+ if model_state0 is None:
1500
+ # e.g. for no gradio case, set dummy value, else should be set
1501
+ model_state0 = [None, None, None, None]
1502
+
1503
+ if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
1504
+ # try to free-up original model (i.e. list was passed as reference)
1505
+ if model_state0 is not None and model_state0[0] is not None:
1506
+ model_state0[0].cpu()
1507
+ model_state0[0] = None
1508
+ # try to free-up original tokenizer (i.e. list was passed as reference)
1509
+ if model_state0 is not None and model_state0[1] is not None:
1510
+ model_state0[1] = None
1511
+ clear_torch_cache()
1512
+ model, tokenizer, device, base_model = model_state
1513
+ elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
1514
+ assert isinstance(model_state[0], str)
1515
+ model, tokenizer, device, base_model = model_state0
1516
+ else:
1517
+ raise AssertionError(no_model_msg)
1518
+
1519
+ if base_model is None:
1520
+ raise AssertionError(no_model_msg)
1521
+
1522
+ assert base_model.strip(), no_model_msg
1523
+ assert model, "Model is missing"
1524
+ assert tokenizer, "Tokenizer is missing"
1525
+
1526
+ # choose chat or non-chat mode
1527
+ if not chat:
1528
+ instruction = instruction_nochat
1529
+ iinput = iinput_nochat
1530
+
1531
+ data_point = dict(context=context, instruction=instruction, input=iinput)
1532
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
1533
+ prompt = prompter.generate_prompt(data_point)
1534
+
1535
+ if hard_stop_list is None:
1536
+ # acts like undo on user entry and bot response
1537
+ hard_stop_list = []
1538
+
1539
+ if isinstance(tokenizer, str):
1540
+ # pipeline
1541
+ if tokenizer == "summarization":
1542
+ key = 'summary_text'
1543
+ else:
1544
+ raise RuntimeError("No such task type %s" % tokenizer)
1545
+ # NOTE: uses max_length only
1546
+ yield model(prompt, max_length=max_new_tokens)[0][key]
+ return  # pipeline path finished; don't fall through to the HF generate path below
1547
+
1548
+ if 'mbart-' in base_model.lower():
1549
+ assert src_lang is not None
1550
+ tokenizer.src_lang = languages_covered()[src_lang]
1551
+
1552
+ if chat:
1553
+ # override, ignore user change
1554
+ num_return_sequences = 1
1555
+ if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
1556
+ if prompt_type == 'human_bot':
1557
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
1558
+ # stopping only starts once output is beyond prompt
1559
+ # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
1560
+ stop_words = [human, bot, '\n' + human, '\n' + bot]
1561
+ encounters = [1, 2]
1562
+ elif prompt_type == 'instruct_vicuna':
1563
+ # even below is not enough, generic strings and many ways to encode
1564
+ stop_words = [
1565
+ '### Human:',
1566
+ """
1567
+ ### Human:""",
1568
+ """
1569
+ ### Human:
1570
+ """,
1571
+ '### Assistant:',
1572
+ """
1573
+ ### Assistant:""",
1574
+ """
1575
+ ### Assistant:
1576
+ """,
1577
+ ]
1578
+ encounters = [1, 2]
1579
+ else:
1580
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
1581
+ stop_words = ['### End']
1582
+ encounters = [1]
1583
+ stop_words_ids = [
1584
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
1585
+ # handle single token case
1586
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
1587
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
1588
+ # avoid padding in front of tokens
1589
+ if tokenizer.pad_token:
1590
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
1591
+ # handle fake \n added
1592
+ stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
1593
+ # build stopper
1594
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
1595
+ else:
1596
+ stopping_criteria = StoppingCriteriaList()
1597
+
1598
+ # help to avoid errors like:
1599
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1600
+ # RuntimeError: expected scalar type Half but found Float
1601
+ # with - 256
1602
+ max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
1603
+ cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1604
+ output_smallest = 30 * 4
1605
+ prompt = prompt[-cutoff_len - output_smallest:]
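+ # keep only the most recent part of the prompt, reserving room for at least ~30 tokens of output (again assuming ~4 chars per token)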
1606
+ inputs = tokenizer(prompt,
1607
+ return_tensors="pt",
1608
+ truncation=True,
1609
+ max_length=max_length_tokenize)
1610
+ if debug and len(inputs["input_ids"]) > 0:
1611
+ print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1612
+ input_ids = inputs["input_ids"].to(device)
1613
+ generation_config = GenerationConfig(
1614
+ temperature=float(temperature),
1615
+ top_p=float(top_p),
1616
+ top_k=top_k,
1617
+ num_beams=num_beams,
1618
+ do_sample=do_sample,
1619
+ repetition_penalty=float(repetition_penalty),
1620
+ num_return_sequences=num_return_sequences,
1621
+ renormalize_logits=True,
1622
+ remove_invalid_values=True,
1623
+ **kwargs,
1624
+ )
1625
+
1626
+ gen_kwargs = dict(input_ids=input_ids,
1627
+ generation_config=generation_config,
1628
+ return_dict_in_generate=True,
1629
+ output_scores=True,
1630
+ max_new_tokens=max_new_tokens, # prompt + new
1631
+ min_new_tokens=min_new_tokens, # prompt + new
1632
+ early_stopping=early_stopping, # False, True, "never"
1633
+ max_time=max_time,
1634
+ stopping_criteria=stopping_criteria,
1635
+ )
1636
+ if 'gpt2' in base_model.lower():
1637
+ gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
1638
+ elif 'mbart-' in base_model.lower():
1639
+ assert tgt_lang is not None
1640
+ tgt_lang = languages_covered()[tgt_lang]
1641
+ gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1642
+ else:
1643
+ gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
1644
+
1645
+ decoder = functools.partial(tokenizer.decode,
1646
+ skip_special_tokens=True,
1647
+ clean_up_tokenization_spaces=True,
1648
+ )
1649
+ decoder_raw = functools.partial(tokenizer.decode,
1650
+ skip_special_tokens=False,
1651
+ clean_up_tokenization_spaces=True,
1652
+ )
1653
+
1654
+ with torch.no_grad():
1655
+ # decoded tokenized prompt can deviate from prompt due to special characters
1656
+ inputs_decoded = decoder(input_ids[0])
1657
+ inputs_decoded_raw = decoder_raw(input_ids[0])
1658
+ if inputs_decoded == prompt:
1659
+ # normal
1660
+ pass
1661
+ elif inputs_decoded.lstrip() == prompt.lstrip():
1662
+ # sometimes extra space in front, make prompt same for prompt removal
1663
+ prompt = inputs_decoded
1664
+ elif inputs_decoded_raw == prompt:
1665
+ # some models specify special tokens that are part of normal prompt, so can't skip them
1666
+ inputs_decoded_raw = inputs_decoded
1667
+ decoder = decoder_raw
1668
+ else:
1669
+ print("WARNING: Special characters in prompt", flush=True)
1670
+ if stream_output:
1671
+ def generate(callback=None, **kwargs):
1672
+ # re-order stopping so Stream first and get out all chunks before stop for other reasons
1673
+ stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
1674
+ kwargs['stopping_criteria'] = StoppingCriteriaList()
1675
+ kwargs['stopping_criteria'].append(Stream(func=callback))
1676
+ for stopping_criteria1 in stopping_criteria0:
1677
+ kwargs['stopping_criteria'].append(stopping_criteria1)
1678
+
1679
+ try:
1680
+ model.generate(**kwargs)
1681
+ except torch.cuda.OutOfMemoryError as e:
1682
+ print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1683
+ flush=True)
1684
+ if kwargs['input_ids'] is not None:
1685
+ kwargs['input_ids'].cpu()
1686
+ kwargs['input_ids'] = None
1687
+ traceback.print_exc()
1688
+ clear_torch_cache()
1689
+ return
1690
+ except (Exception, RuntimeError) as e:
1691
+ if 'Expected all tensors to be on the same device' in str(e) or \
1692
+ 'expected scalar type Half but found Float' in str(e) or \
1693
+ 'probability tensor contains either' in str(e) or \
1694
+ 'cublasLt ran into an error!' in str(e):
1695
+ print(
1696
+ "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1697
+ flush=True)
1698
+ traceback.print_exc()
1699
+ clear_torch_cache()
1700
+ if raise_generate_gpu_exceptions:
1701
+ raise
1702
+ return
1703
+ else:
1704
+ raise
1705
+
1706
+ decoded_output = None
1707
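+ # CallbackToGenerator adapts the callback-style generate() into a generator: the Stream stopping criterion
+ # reports the tokens produced so far, and each partial sequence is decoded and yielded below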
+ for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
1708
+ decoded_output = decoder(output)
1709
+ if output[-1] in [tokenizer.eos_token_id]:
1710
+ if debug:
1711
+ print("HIT EOS", flush=True)
1712
+ break
1713
+ if any(ele in decoded_output for ele in hard_stop_list):
1714
+ raise StopIteration
1715
+ yield prompter.get_response(decoded_output, prompt=inputs_decoded,
1716
+ sanitize_bot_response=sanitize_bot_response)
1717
+ if save_dir and decoded_output:
1718
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1719
+ else:
1720
+ outputs = model.generate(**gen_kwargs)
1721
+ outputs = [decoder(s) for s in outputs.sequences]
1722
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
1723
+ sanitize_bot_response=sanitize_bot_response)
1724
+ if save_dir and outputs and len(outputs) >= 1:
1725
+ decoded_output = prompt + outputs[0]
1726
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1727
+
1728
+
1729
+ def get_generate_params(model_lower, chat,
1730
+ stream_output, show_examples,
1731
+ prompt_type, temperature, top_p, top_k, num_beams,
1732
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
1733
+ repetition_penalty, num_return_sequences,
1734
+ do_sample):
1735
+ use_defaults = False
1736
+ use_default_examples = True
1737
+ examples = []
1738
+ task_info = f"{prompt_type}"
1739
+ if model_lower:
1740
+ print(f"Using Model {model_lower}", flush=True)
1741
+ else:
1742
+ print("No model defined yet", flush=True)
1743
+
1744
+ min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1745
+ early_stopping = early_stopping if early_stopping is not None else False
1746
+ max_time_defaults = 60 * 3
1747
+ max_time = max_time if max_time is not None else max_time_defaults
1748
+
1749
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1750
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
1751
+
1752
+ # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1753
+ if show_examples is None:
1754
+ if chat:
1755
+ show_examples = False
1756
+ else:
1757
+ show_examples = True
1758
+
1759
+ summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
1760
+ Philipp: Sure you can use the new Hugging Face Deep Learning Container.
1761
+ Jeff: ok.
1762
+ Jeff: and how can I get started?
1763
+ Jeff: where can I find documentation?
1764
+ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
1765
+
1766
+ if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
1767
+ placeholder_instruction = summarize_example1
1768
+ placeholder_input = ""
1769
+ use_defaults = True
1770
+ use_default_examples = False
1771
+ examples += [
1772
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1773
+ 1.0, 1,
1774
+ False]]
1775
+ task_info = "Summarization"
1776
+ elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
1777
+ placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
1778
+ placeholder_input = ""
1779
+ use_defaults = True
1780
+ use_default_examples = True
1781
+ task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
1782
+ elif 'mbart-' in model_lower:
1783
+ placeholder_instruction = "The girl has long hair."
1784
+ placeholder_input = ""
1785
+ use_defaults = True
1786
+ use_default_examples = False
1787
+ examples += [
1788
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1789
+ 1.0, 1,
1790
+ False]]
1791
+ elif 'gpt2' in model_lower:
1792
+ placeholder_instruction = "The sky is"
1793
+ placeholder_input = ""
1794
+ prompt_type = prompt_type or 'plain'
1795
+ use_default_examples = True # some will be odd "continuations" but can be ok
1796
+ examples += [
1797
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1798
+ 1.0, 1,
1799
+ False]]
1800
+ task_info = "Auto-complete phrase, code, etc."
1801
+ use_defaults = True
1802
+ else:
1803
+ if chat:
1804
+ placeholder_instruction = "Enter a question or imperative."
1805
+ else:
1806
+ placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1807
+ placeholder_input = ""
1808
+ if model_lower:
1809
+ prompt_type = prompt_type or 'human_bot'
1810
+ else:
1811
+ prompt_type = ''
1812
+ examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
1813
+ stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
1814
+ False]]
1815
+ task_info = "No task"
1816
+ if prompt_type == 'instruct':
1817
+ task_info = "Answer question or follow imperative as instruction with optionally input."
1818
+ elif prompt_type == 'plain':
1819
+ task_info = "Auto-complete phrase, code, etc."
1820
+ elif prompt_type == 'human_bot':
1821
+ if chat:
1822
+ task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1823
+ else:
1824
+ task_info = "Ask question/imperative (input concatenated with instruction)"
1825
+
1826
+ # revert to plain if still nothing
1827
+ prompt_type = prompt_type or 'plain'
1828
+ if use_defaults:
1829
+ temperature = 1.0 if temperature is None else temperature
1830
+ top_p = 1.0 if top_p is None else top_p
1831
+ top_k = 40 if top_k is None else top_k
1832
+ num_beams = num_beams or 1
1833
+ max_new_tokens = max_new_tokens or 128
1834
+ repetition_penalty = repetition_penalty or 1.07
1835
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1836
+ do_sample = False if do_sample is None else do_sample
1837
+ else:
1838
+ temperature = 0.1 if temperature is None else temperature
1839
+ top_p = 0.75 if top_p is None else top_p
1840
+ top_k = 40 if top_k is None else top_k
1841
+ if chat:
1842
+ num_beams = num_beams or 1
1843
+ else:
1844
+ num_beams = num_beams or 4
1845
+ max_new_tokens = max_new_tokens or 256
1846
+ repetition_penalty = repetition_penalty or 1.07
1847
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1848
+ do_sample = False if do_sample is None else do_sample
1849
+ # doesn't include chat, instruction_nochat, iinput_nochat, added later
1850
+ params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1851
+ early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1852
+
1853
+ if use_default_examples:
1854
+ examples += [
1855
+ ["Translate English to French", "Good morning"] + params_list,
1856
+ ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1857
+ ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1858
+ [
1859
+ "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1860
+ ''] + params_list,
1861
+ ['Translate to German: My name is Arthur', ''] + params_list,
1862
+ ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1863
+ ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1864
+ ''] + params_list,
1865
+ ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1866
+ ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1867
+ ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1868
+ [
1869
+ "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1870
+ ''] + params_list,
1871
+ ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1872
+ [
1873
+ 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apples do they have?',
1874
+ ''] + params_list,
1875
+ ["""def area_of_rectangle(a: float, b: float):
1876
+ \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1877
+ ["""# a function in native python:
1878
+ def mean(a):
1879
+ return sum(a)/len(a)
1880
+
1881
+ # the same function using numpy:
1882
+ import numpy as np
1883
+ def mean(a):""", ''] + params_list,
1884
+ ["""X = np.random.randn(100, 100)
1885
+ y = np.random.randint(0, 1, 100)
1886
+
1887
+ # fit random forest classifier with 20 estimators""", ''] + params_list,
1888
+ ]
1889
+
1890
+ src_lang = "English"
1891
+ tgt_lang = "Russian"
1892
+
1893
+ # move to correct position
1894
+ for example in examples:
1895
+ example += [chat, '', '']
1896
+ # adjust examples if non-chat mode
1897
+ if not chat:
1898
+ example[eval_func_param_names.index('instruction_nochat')] = example[
1899
+ eval_func_param_names.index('instruction')]
1900
+ example[eval_func_param_names.index('instruction')] = ''
1901
+
1902
+ example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1903
+ example[eval_func_param_names.index('iinput')] = ''
1904
+
1905
+ return placeholder_instruction, placeholder_input, \
1906
+ stream_output, show_examples, \
1907
+ prompt_type, temperature, top_p, top_k, num_beams, \
1908
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1909
+ repetition_penalty, num_return_sequences, \
1910
+ do_sample, \
1911
+ src_lang, tgt_lang, \
1912
+ examples, \
1913
+ task_info
1914
+
1915
+
1916
+ def languages_covered():
1917
+ # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1918
+ covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1919
+ covered = covered.split(', ')
1920
+ covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1921
+ return covered
1922
+
1923
+
1924
+ def test_test_prompt(prompt_type='instruct', data_point=0):
1925
+ example_data_point = example_data_points[data_point]
1926
+ example_data_point.pop('output', None)
1927
+ return generate_prompt(example_data_point, prompt_type, False, False)
1928
+
1929
+
1930
+ if __name__ == "__main__":
1931
+ print("""
1932
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1933
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1934
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1935
+
1936
+ # generate without lora weights, no prompt
1937
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1938
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1939
+
1940
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1941
+ # OpenChatKit settings:
1942
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1943
+
1944
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1945
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1946
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1947
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
1948
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1949
+
1950
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1951
+
1952
+ must have 4x48GB GPUs and run without 8-bit in order for sharding to work with infer_devices=False
1953
+ can also pass --prompt_type='human_bot'; the model can somewhat handle instructions without being instruct-tuned
1954
+ python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1955
+
1956
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
1957
+
1958
+ """, flush=True)
1959
+ fire.Fire(main)
client_test.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ Client test.
3
+
4
+ Run server:
5
+
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
7
+
8
+ NOTE: For private models, add --use_auth_token=True
9
+
10
+ NOTE: --infer_devices=True (default) must be used for multi-GPU; use it if you see failures with cuda:x vs. cuda:y device mismatches.
11
+ Currently, this forces the model onto a single GPU.
12
+
13
+ Then run this client as:
14
+
15
+ python client_test.py
16
+ """
17
+
18
+ debug = False
19
+
20
+ import os
21
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
22
+ from gradio_client import Client
23
+
24
+ client = Client("http://localhost:7860")
25
+ if debug:
26
+ print(client.view_api(all_endpoints=True))
27
+
28
+ instruction = '' # only for chat=True
29
+ iinput = '' # only for chat=True
30
+ context = ''
31
+ # streaming output is supported, loops over and outputs each generation in streaming mode
32
+ # but leave stream_output=False for simple input/output mode
33
+ stream_output = False
34
+ prompt_type = 'human_bot'
35
+ temperature = 0.1
36
+ top_p = 0.75
37
+ top_k = 40
38
+ num_beams = 1
39
+ max_new_tokens = 50
40
+ min_new_tokens = 0
41
+ early_stopping = False
42
+ max_time = 20
43
+ repetition_penalty = 1.0
44
+ num_return_sequences = 1
45
+ do_sample = True
46
+ # only these 2 below used if pass chat=False
47
+ chat = False
48
+ instruction_nochat = "Who are you?"
49
+ iinput_nochat = ''
50
+
51
+
52
+ def test_client_basic():
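+ # args must follow the same order as evaluate()'s parameters on the server (see eval_func_param_names in app.py)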
53
+ args = [instruction,
54
+ iinput,
55
+ context,
56
+ stream_output,
57
+ prompt_type,
58
+ temperature,
59
+ top_p,
60
+ top_k,
61
+ num_beams,
62
+ max_new_tokens,
63
+ min_new_tokens,
64
+ early_stopping,
65
+ max_time,
66
+ repetition_penalty,
67
+ num_return_sequences,
68
+ do_sample,
69
+ chat,
70
+ instruction_nochat,
71
+ iinput_nochat,
72
+ ]
73
+ api_name = '/submit_nochat'
74
+ res = client.predict(
75
+ *tuple(args),
76
+ api_name=api_name,
77
+ )
78
+ res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
79
+ print(res_dict)
80
+
81
+
82
+ import markdown # pip install markdown
83
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
84
+
85
+
86
+ def md_to_text(md):
87
+ html = markdown.markdown(md)
88
+ soup = BeautifulSoup(html, features='html.parser')
89
+ return soup.get_text()
90
+
91
+
92
+ if __name__ == '__main__':
93
+ test_client_basic()
finetune.py ADDED
@@ -0,0 +1,934 @@
1
+ import os
2
+ import pathlib
3
+ import random
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from datetime import datetime
9
+ from typing import List, Union
10
+ import fire
11
+ import numpy as np
12
+ import torch
13
+ from datasets import load_dataset, concatenate_datasets
14
+ import transformers
15
+ import torch.distributed as dist
16
+
17
+ from peft import (
18
+ prepare_model_for_int8_training,
19
+ LoraConfig,
20
+ get_peft_model,
21
+ get_peft_model_state_dict,
22
+ set_peft_model_state_dict,
23
+ )
24
+
25
+ from peft import mapping
26
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
27
+
28
+
29
+ def log(*args, **kwargs):
30
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
31
+ print(*args, **kwargs)
32
+
33
+
34
+ try:
35
+ import neptune
36
+ from transformers.integrations import NeptuneCallback
37
+
38
+ neptune_run = neptune.init_run(
39
+ source_files=[],
40
+ )
41
+ log("Connected to Neptune.")
42
+ except ImportError:
43
+ neptune_run = None
44
+ log("Please pip install neptune for tracking.")
45
+ except neptune.exceptions.NeptuneMissingApiTokenException:
46
+ neptune_run = None
47
+ os.environ["NEPTUNE_MODE"] = 'debug'
48
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
49
+
50
+ from enum import Enum
51
+
52
+
53
+ class PromptType(Enum):
54
+ plain = 0
55
+ instruct = 1
56
+ quality = 2
57
+ human_bot = 3
58
+ dai_faq = 4
59
+ summarize = 5
60
+ simple_instruct = 6
61
+ instruct_vicuna = 7
62
+ instruct_with_end = 8
63
+ human_bot_orig = 9
64
+
65
+
66
+ prompt_type_to_model_name = {
67
+ 'plain': [
68
+ 'EleutherAI/gpt-j-6B',
69
+ 'EleutherAI/pythia-6.9b',
70
+ 'EleutherAI/pythia-12b',
71
+ 'EleutherAI/pythia-12b-deduped',
72
+ 'EleutherAI/gpt-neox-20b',
73
+ 'decapoda-research/llama-7b-hf',
74
+ 'decapoda-research/llama-13b-hf',
75
+ 'decapoda-research/llama-30b-hf',
76
+ 'decapoda-research/llama-65b-hf',
77
+ 'facebook/mbart-large-50-many-to-many-mmt',
78
+ 'philschmid/bart-large-cnn-samsum',
79
+ 'philschmid/flan-t5-base-samsum',
80
+ 'gpt2',
81
+ 'distilgpt2',
82
+ ],
83
+ 'instruct': [],
84
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
85
+ 'quality': [],
86
+ 'human_bot': [
87
+ 'h2oai/h2ogpt-oig-oasst1-256-12b',
88
+ 'h2oai/h2ogpt-oasst1-512-12b',
89
+ 'h2oai/h2ogpt-oasst1-256-20b',
90
+ 'h2oai/h2ogpt-oasst1-512-20b',
91
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b',
92
+ ],
93
+ 'dai_faq': [],
94
+ 'summarize': [],
95
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
96
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
97
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
98
+ }
99
+
100
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
101
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
102
+
103
+ human = '<human>:'
104
+ bot = "<bot>:"
105
+
106
+ prompt_types_strings = []
107
+ for p in PromptType:
108
+ prompt_types_strings.extend([p.name])
109
+
110
+
111
+ prompt_types = []
112
+ for p in PromptType:
113
+ prompt_types.extend([p.name, p.value, str(p.value)])
114
+
115
+
116
+ # supported by huggingface evaluate
117
+ supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
118
+
119
+
120
+ def train(
121
+ save_code: bool = False,
122
+ run_id: int = None,
123
+
124
+ base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
125
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
126
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
127
+ # base_model: str = 'EleutherAI/gpt-neox-20b',
128
+ # base_model: str = 'EleutherAI/pythia-12b-deduped',
129
+ # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
130
+ # base_model: str = 'decapoda-research/llama-7b-hf',
131
+ # base_model: str = 'decapoda-research/llama-13b-hf',
132
+ # base_model: str = 'decapoda-research/llama-30b-hf',
133
+ # base_model: str = 'EleutherAI/gpt-j-6B',
134
+
135
+ # only needed if base_model is self-exported HF state without tokenizer
136
+ tokenizer_base_model: str = None,
137
+ # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
138
+
139
+ data_path: str = None,
140
+ data_col_dict: dict = None,
141
+ # data_path: str = "./dai_docs.train.json",
142
+ prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
143
+
144
+ valid_path: str = None,
145
+ # valid_path: str = "./dai_docs.valid.json",
146
+
147
+ # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
148
+ data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
149
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
150
+ data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
151
+ data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
152
+
153
+ output_dir: str = None,
154
+
155
+ # LoRA checkpoint continuation
156
+ lora_weights: str = "",
157
+
158
+ # batching training hyperparams
159
+ batch_size: int = 128,
160
+ micro_batch_size: int = 4,
161
+ gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
162
+ fp16=True,
163
+
164
+ # general training hyperparams
165
+ num_epochs: float = 1,
166
+ learning_rate: float = 3e-4,
167
+
168
+ # validation settings
169
+ val_set_size: int = None,
170
+ val_metrics: List[str] = [],
171
+ eval_steps: int = None, # to control eval steps via steps
172
+ eval_epochs: float = None, # to control eval steps via epochs
173
+
174
+ # lora hyperparams
175
+ lora_r: int = 8,
176
+ lora_alpha: int = 16,
177
+ lora_dropout: float = 0.05,
178
+ lora_target_modules: List[str] = None,
179
+ llama_type: bool = None,
180
+
181
+ # llm hyperparams
182
+ train_on_inputs: bool = True, # if False, masks out inputs in loss
183
+ group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
184
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
185
+ cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
186
+
187
+ # torch training params
188
+ ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
189
+ local_files_only: bool = False, # else will download new versions, normally unwanted
190
+ resume_download: bool = True,
191
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
192
+ warmup_steps: int = 100,
193
+ logging_steps: int = 1,
194
+ save_steps: int = None, # must be round multiple of eval_steps
195
+ add_eos_token: bool = False,
196
+ ):
197
+ # allow set token directly
198
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
199
+
200
+ prompt_type = str(prompt_type) # migration from integers
201
+ assert prompt_type in prompt_types
202
+
203
+ world_size = int(os.getenv("WORLD_SIZE", 1))
204
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
205
+ rank = int(os.getenv("RANK", 0))
206
+ print(f"local_rank: {local_rank}")
207
+ print(f"global rank: {rank}")
208
+
209
+ gpus = max(world_size, torch.cuda.device_count())
210
+ run_id = run_id or 0
211
+ if not data_path:
212
+ raise ValueError("No data_path provided")
213
+ if not output_dir:
214
+ output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
215
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
216
+ raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
217
+ else:
218
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
219
+ raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
220
+ device_map = "auto"
221
+
222
+ if save_code:
223
+ copy_code(run_id)
224
+ if tokenizer_base_model is None:
225
+ tokenizer_base_model = base_model
226
+ if llama_type is None:
227
+ llama_type = "llama" in base_model.lower()
228
+ assert (
229
+ base_model
230
+ ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
231
+ gradient_accumulation_steps = batch_size // micro_batch_size
232
+ assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
233
+
234
+ device_map = "auto"
235
+
236
+ locals_dict = locals()
237
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
238
+ log(f"Training model with params:\n{locals_print}")
239
+ log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
240
+
241
+ max_memory = None
242
+ if gpus > 1:
243
+ if ddp:
244
+ log("Distributed: data parallel")
245
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
246
+ gradient_accumulation_steps = gradient_accumulation_steps // world_size
247
+ else:
248
+ free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
249
+ max_memory = f"{free_in_GB - 2}GB"
250
+ max_memory = {i: max_memory for i in range(gpus)}
251
+ log("world_size: %d" % world_size)
252
+ log("num_gpus: %d" % gpus)
253
+ log("max mem: %s" % max_memory)
254
+
255
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
256
+
257
+ model = model_loader.from_pretrained(
258
+ base_model,
259
+ load_in_8bit=True,
260
+ device_map=device_map,
261
+ torch_dtype=torch.float16,
262
+ max_memory=max_memory,
263
+ local_files_only=local_files_only,
264
+ resume_download=resume_download,
265
+ use_auth_token=use_auth_token,
266
+ )
267
+ if gpus > 1:
268
+ if not ddp:
269
+ log("model parallel")
270
+ model.is_parallelizable = True
271
+ model.model_parallel = True
272
+
273
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
274
+ local_files_only=local_files_only,
275
+ resume_download=resume_download,
276
+ use_auth_token=use_auth_token)
277
+
278
+ tokenizer.pad_token_id = 0 # different from the eos token
279
+ # when generating, we use the logits of the right-most token to predict the next token,
280
+ # so the padding should be on the left,
281
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
282
+ tokenizer.padding_side = "left" # Allow batched inference
283
+
284
+ def tokenize(prompt, add_eos_token=True):
285
+ # there's probably a way to do this with the tokenizer settings
286
+ # but again, gotta move fast
287
+ result = tokenizer(
288
+ prompt,
289
+ truncation=True,
290
+ max_length=cutoff_len,
291
+ padding=False,
292
+ return_tensors=None,
293
+ )
294
+ if (
295
+ result["input_ids"][-1] != tokenizer.eos_token_id
296
+ and len(result["input_ids"]) < cutoff_len
297
+ and add_eos_token
298
+ ):
299
+ result["input_ids"].append(tokenizer.eos_token_id)
300
+ result["attention_mask"].append(1)
301
+
302
+ result["labels"] = result["input_ids"].copy()
303
+
304
+ return result
305
+
306
+ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
307
+ full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
308
+ tokenized_full_prompt = tokenize(full_prompt)
309
+ if not train_on_inputs:
310
+ user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
311
+ tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
312
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
313
+ if add_eos:
314
+ user_prompt_len -= 1
315
+
316
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
317
+ tokenized_full_prompt["labels"] = [
318
+ -100
319
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
320
+ user_prompt_len:
321
+ ] # could be sped up, probably
322
+ return tokenized_full_prompt
323
+
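A minimal sketch (not from the commit; token ids invented) of what the -100 label masking above accomplishes when train_on_inputs=False: prompt tokens are excluded from the loss, so only response tokens contribute.

    # hypothetical token ids for the prompt and the response
    user_ids = [101, 102, 103]            # assumed prompt tokens
    full_ids = user_ids + [201, 202, 2]   # assumed prompt + response + eos
    labels = [-100] * len(user_ids) + full_ids[len(user_ids):]
    assert labels == [-100, -100, -100, 201, 202, 2]  # CrossEntropyLoss ignores the -100 positions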
324
+ if "gpt-neox" not in base_model or True:
325
+ model = prepare_model_for_int8_training(model)
326
+ else:
327
+ model = prepare_model_for_int8_training(
328
+ model,
329
+ output_embedding_layer_name="embed_out", # keep output logits in float32
330
+ layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
331
+ )
332
+ if lora_weights:
333
+ from peft import PeftModel
334
+ model = PeftModel.from_pretrained(
335
+ model,
336
+ lora_weights,
337
+ torch_dtype=torch.float16,
338
+ device_map=device_map,
339
+ local_files_only=local_files_only,
340
+ resume_download=resume_download,
341
+ use_auth_token=use_auth_token,
342
+ )
343
+ else:
344
+ if lora_target_modules is None:
345
+ base_model_lower = base_model.lower()
346
+ if base_model_lower in lora_mappings:
347
+ lora_target_modules_cand = [lora_mappings[base_model_lower]]
348
+ else:
349
+ lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
350
+ else:
351
+ lora_target_modules_cand = [lora_target_modules]
352
+
353
+ for lora_target_modules in lora_target_modules_cand:
354
+ try:
355
+ config = LoraConfig(
356
+ r=lora_r,
357
+ lora_alpha=lora_alpha,
358
+ target_modules=lora_target_modules,
359
+ lora_dropout=lora_dropout,
360
+ bias="none",
361
+ task_type="CAUSAL_LM",
362
+ )
363
+ model = get_peft_model(model, config)
364
+ break
365
+ except ValueError as e:
366
+ if "Target modules" in str(e) and "not found" in str(e):
367
+ continue
368
+ else:
369
+ raise
370
+ from peft import PeftModel
371
+ assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
372
+ if resume_from_checkpoint:
373
+ # Check the available weights and load them
374
+ checkpoint_name = os.path.join(
375
+ resume_from_checkpoint, "pytorch_model.bin"
376
+ ) # Full checkpoint
377
+ if not os.path.exists(checkpoint_name):
378
+ checkpoint_name = os.path.join(
379
+ resume_from_checkpoint, "adapter_model.bin"
380
+ ) # only LoRA model - LoRA config above has to fit
381
+ resume_from_checkpoint = False # So the trainer won't try loading its state
382
+ # The two files above have a different name depending on how they were saved, but are actually the same.
383
+ if os.path.exists(checkpoint_name):
384
+ log(f"Restarting from {checkpoint_name}")
385
+ adapters_weights = torch.load(checkpoint_name)
386
+ model = set_peft_model_state_dict(model, adapters_weights)
387
+ else:
388
+ log(f"Checkpoint {checkpoint_name} not found")
389
+
390
+ print(model)
391
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
392
+
393
+ metrics = {}
394
+ for name in supported_metrics:
395
+ if name in val_metrics:
396
+ import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
397
+ metrics[name] = evaluate.load(name)
398
+ log("Using Validation Metrics: %s" % str(list(metrics.keys())))
399
+ log("Supported Metrics: %s" % supported_metrics)
400
+
401
+ if val_set_size is None:
402
+ if len(metrics) == 0:
403
+ val_set_size = 1000
404
+ else:
405
+ val_set_size = 100
406
+ log("Auto set val_set_size %s" % val_set_size)
407
+ elif val_set_size < 1.0 and val_set_size != 0:
408
+ raise RuntimeError("Fractional validation size not supported.")
409
+
410
+ if valid_path:
411
+ data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
412
+ else:
413
+ if "json" in data_path:
414
+ data = load_dataset("json", data_files={"train": data_path})
415
+ else:
416
+ data = load_dataset(data_path)
417
+ data = data.rename_columns(data_col_dict or {})
418
+
419
+ valid_data = None
420
+ train_data_mix_in = None
421
+ valid_data_mix_in = None
422
+
423
+ if data_mix_in_path and data_mix_in_factor > 0:
424
+ # get mix-in training/validation data - to keep model "sane"
425
+ num_rows = data["train"].num_rows
426
+ log("Loading mix-in dataset: %s" % data_mix_in_path)
427
+ if "json" in data_mix_in_path:
428
+ data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
429
+ else:
430
+ data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
431
+ data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
432
+
433
+ # only get as much as we need to balance
434
+ valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
435
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
436
+ mixin_small = data_mix_in.train_test_split(
437
+ test_size=train_size + valid_size,
438
+ shuffle=True, seed=np.random.randint(10000),
439
+ )["test"]
440
+ if valid_size:
441
+ mixin_train_test = mixin_small.train_test_split(
442
+ test_size=valid_size, shuffle=False,
443
+ )
444
+ train_data_mix_in = mixin_train_test["train"]
445
+ valid_data_mix_in = mixin_train_test["test"]
446
+ else:
447
+ train_data_mix_in = mixin_small
448
+
449
+ if "prompt_type" not in train_data_mix_in.column_names:
450
+ train_data_mix_in = train_data_mix_in.add_column(
451
+ "prompt_type",
452
+ [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
453
+ )
454
+ log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
455
+ if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
456
+ valid_data_mix_in = valid_data_mix_in.add_column(
457
+ "prompt_type",
458
+ [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
459
+ )
460
+ log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
461
+ log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
462
+
463
+ # get our own training/validation data - for fine-tuning
464
+ if val_set_size > 0 and not valid_path and not data_mix_in_path:
465
+ # create valid split from train
466
+ train_val = data["train"].train_test_split(
467
+ test_size=val_set_size, shuffle=True, seed=42
468
+ )
469
+ train_data = train_val["train"]
470
+ valid_data = train_val["test"]
471
+ else:
472
+ train_data = data["train"]
473
+ if valid_path:
474
+ # use given valid split, has priority over data_mix_in_path
475
+ valid_data = data["valid"]
476
+ if "prompt_type" not in train_data.column_names:
477
+ train_data = train_data.add_column(
478
+ "prompt_type",
479
+ [prompt_type] * train_data.num_rows,
480
+ )
481
+ log("Added prompt type %s to training data" % prompt_type)
482
+ if valid_data and "prompt_type" not in valid_data.column_names:
483
+ valid_data = valid_data.add_column(
484
+ "prompt_type",
485
+ [prompt_type] * valid_data.num_rows,
486
+ )
487
+ log("Added prompt type %s to validation data" % prompt_type)
488
+
489
+ assert train_data is not None
490
+
491
+ # shuffle and tokenize data
492
+ if train_data_mix_in:
493
+ train_data = concatenate_datasets([train_data, train_data_mix_in])
494
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
495
+ train_set_size = len(train_data)
496
+
497
+ if valid_data and valid_data_mix_in:
498
+ valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
499
+ elif valid_data_mix_in:
500
+ valid_data = valid_data_mix_in
501
+
502
+ if valid_data:
503
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
504
+ val_set_size = len(valid_data)
505
+ else:
506
+ val_set_size = 0
507
+ log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
508
+ sample_row_dict = train_data[:1]
509
+ del sample_row_dict['input_ids']
510
+ del sample_row_dict['attention_mask']
511
+ del sample_row_dict['labels']
512
+ log("Sample input: %s" % sample_row_dict)
513
+
514
+ if neptune_run:
515
+ neptune_callback = NeptuneCallback(run=neptune_run)
516
+ callbacks = [neptune_callback]
517
+ else:
518
+ from transformers.integrations import TensorBoardCallback, is_tensorboard_available
519
+ if is_tensorboard_available():
520
+ # tensorboard --logdir=runs/
521
+ from torch.utils.tensorboard import SummaryWriter
522
+ tb_writer = SummaryWriter()
523
+ callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
524
+ else:
525
+ callbacks = []
526
+
527
+ expected_steps = (train_set_size * num_epochs) // batch_size
528
+ if eval_steps is None and eval_epochs is None:
529
+ # 20 evaluations for a run
530
+ eval_steps = max(1, int(expected_steps / 20))
531
+ log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
532
+ elif eval_steps is None and eval_epochs is not None:
533
+ eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
534
+ log("Auto converted eval_epochs=%s to eval_steps %s"
535
+ " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
536
+ if save_steps is None:
537
+ save_steps = eval_steps
538
+ log("Auto step save_steps to %s" % save_steps)
539
+ elif save_steps > eval_steps:
540
+ # save steps must be round multiple of eval_steps
541
+ save_steps0 = save_steps
542
+ save_steps = max(1, (save_steps//eval_steps)) * eval_steps
543
+ if save_steps0 != save_steps:
544
+ log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
545
+
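As a hedged worked example of the auto-scheduling above (all numbers invented for illustration): 20,000 training rows, one epoch, and batch_size=128 give 156 expected steps, so eval_steps defaults to 7 and a requested save_steps of 20 is rounded down to a multiple of eval_steps.

    # illustrative arithmetic only, mirroring the logic above
    train_set_size, num_epochs, batch_size = 20000, 1, 128          # assumed values
    expected_steps = (train_set_size * num_epochs) // batch_size    # 156
    eval_steps = max(1, int(expected_steps / 20))                   # 7
    save_steps = max(1, 20 // eval_steps) * eval_steps              # 14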
546
+ def compute_metrics(eval_preds):
547
+ # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
548
+ inputs = eval_preds.inputs
549
+ label_ids = eval_preds.label_ids
550
+ predictions = eval_preds.predictions
551
+
552
+ #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
553
+ #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
554
+ #decoded_inputs = [pred.strip() for pred in decoded_inputs]
555
+
556
+ label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
557
+ # tokenizer behavior like generate time
558
+ decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
559
+ clean_up_tokenization_spaces=True)
560
+ decoded_labels = [pred.strip() for pred in decoded_labels]
561
+
562
+ predictions = np.argmax(predictions, -1)
563
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
564
+ # tokenizer behavior like generate time
565
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
566
+ clean_up_tokenization_spaces=True)
567
+ decoded_predictions = [pred.strip() for pred in decoded_predictions]
568
+
569
+ result = {}
570
+ for metric in metrics.values():
571
+ result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
572
+ # get rid of lists, for precision etc., for now
573
+ numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
574
+ result.update(numeric_results)
575
+ return result
576
+
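A short hedged sketch of how the evaluate metrics loaded earlier behave outside the Trainer; the strings are made-up examples and only numeric results are kept, as in compute_metrics above.

    import evaluate
    rouge = evaluate.load("rouge")  # requires the rouge_score package from requirements.txt
    scores = rouge.compute(predictions=["ducks swim in the lake"],
                           references=["ducks eat and swim at the lake"])
    print({k: v for k, v in scores.items() if isinstance(v, (int, float))})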
577
+ # the callback that computes metrics of interest
578
+ if val_metrics:
579
+ trainer_kwargs = dict(compute_metrics=compute_metrics)
580
+ else:
581
+ trainer_kwargs = dict()
582
+
583
+ trainer = transformers.Trainer(
584
+ model=model,
585
+ tokenizer=tokenizer,
586
+ train_dataset=train_data,
587
+ eval_dataset=valid_data,
588
+ # NOTE: CausalLM does not use the Seq2Seq-specific arguments, but Seq2SeqTrainingArguments is not incompatible
589
+ args=transformers.Seq2SeqTrainingArguments(
590
+ per_device_train_batch_size=micro_batch_size,
591
+ per_device_eval_batch_size=1,
592
+ eval_accumulation_steps=10,
593
+ # predict_with_generate=True, # SEQ2SEQ only
594
+ include_inputs_for_metrics=True,
595
+ gradient_accumulation_steps=gradient_accumulation_steps,
596
+ warmup_steps=warmup_steps,
597
+ num_train_epochs=num_epochs,
598
+ learning_rate=learning_rate,
599
+ gradient_checkpointing=gradient_checkpointing,
600
+ fp16=fp16,
601
+ # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
602
+ optim="adamw_torch", # consider "adafactor" to save memory
603
+ logging_steps=logging_steps,
604
+ logging_strategy="steps",
605
+ evaluation_strategy="steps" if val_set_size > 0 else "no",
606
+ save_strategy="steps",
607
+ eval_steps=eval_steps if val_set_size > 0 else None,
608
+ save_steps=save_steps,
609
+ output_dir=output_dir,
610
+ save_total_limit=3,
611
+ load_best_model_at_end=True if val_set_size > 0 else False,
612
+ ddp_find_unused_parameters=False if ddp else None,
613
+ group_by_length=group_by_length,
614
+ #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
615
+ #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
616
+ report_to='tensorboard' if not neptune_run else 'neptune',
617
+ ),
618
+ data_collator=transformers.DataCollatorForSeq2Seq(
619
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
620
+ ),
621
+ callbacks=callbacks,
622
+ **trainer_kwargs,
623
+ )
624
+ model.config.use_cache = False
625
+
626
+ old_state_dict = model.state_dict
627
+ model.state_dict = (
628
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
629
+ ).__get__(model, type(model))
630
+
631
+ if torch.__version__ >= "2" and sys.platform != "win32":
632
+ model = torch.compile(model)
633
+ # WIP (not generally replacing layers until pytorch 2.1)
634
+ torch.backends.cuda.enable_flash_sdp(True)
635
+
636
+ if gpus > 1 and not ddp:
637
+ assert trainer.is_model_parallel
638
+ else:
639
+ assert not trainer.is_model_parallel
640
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
641
+
642
+ model.save_pretrained(output_dir)
643
+
644
+ log("\n If there's a warning about missing keys above, please disregard :)")
645
+
646
+
647
+ def get_loaders(llama_type, model_name, reward_type):
648
+ # NOTE: Some models need specific new prompt_type
649
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".
650
+ if llama_type:
651
+ from transformers import LlamaForCausalLM, LlamaTokenizer
652
+ model_loader = LlamaForCausalLM
653
+ tokenizer_loader = LlamaTokenizer
654
+ elif 'gpt2' in model_name.lower():
655
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
656
+ return GPT2LMHeadModel, GPT2Tokenizer
657
+ elif 'mbart-' in model_name.lower():
658
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
659
+ return MBartForConditionalGeneration, MBart50TokenizerFast
660
+ elif 't5' == model_name.lower() or \
661
+ 't5-' in model_name.lower() or \
662
+ 'flan-' in model_name.lower():
663
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
664
+ return T5ForConditionalGeneration, AutoTokenizer
665
+ elif 'bigbird' in model_name:
666
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
667
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
668
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
669
+ from transformers import pipeline
670
+ return pipeline, "summarization"
671
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
672
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
673
+ return AutoModelForSequenceClassification, AutoTokenizer
674
+ else:
675
+ from transformers import AutoTokenizer, AutoModelForCausalLM
676
+ model_loader = AutoModelForCausalLM
677
+ tokenizer_loader = AutoTokenizer
678
+ return model_loader, tokenizer_loader
679
+
680
+
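For reference, a hedged usage sketch of get_loaders (the model name is only an example); note that a few branches, e.g. the samsum summarization models, return a pipeline factory and a task string instead of model/tokenizer classes.

    model_loader, tokenizer_loader = get_loaders(llama_type=False,
                                                 model_name="EleutherAI/gpt-j-6B",
                                                 reward_type=False)
    # here these are AutoModelForCausalLM and AutoTokenizer, so .from_pretrained()
    # can be called on both, as train() does above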
681
+ def get_githash():
682
+ try:
683
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
684
+ except Exception:
685
+ githash = ''
686
+ return githash
687
+
688
+
689
+ def copy_code(run_id):
690
+ """
691
+ copy code to track changes
692
+ :param run_id:
693
+ :return:
694
+ """
695
+ rnd_num = str(random.randint(0, 2 ** 31))
696
+ run_id = 'run_' + str(run_id)
697
+ os.makedirs(run_id, exist_ok=True)
698
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
699
+ me_file = os.path.basename(__file__)
700
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
701
+ if os.path.isfile(new_me):
702
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
703
+ shutil.copy(me_full, new_me)
704
+ else:
705
+ shutil.copy(me_full, new_me)
706
+
707
+
708
+ def get_prompt(prompt_type, chat, context, reduced):
709
+ if prompt_type in [-1, "-1", "plain"]:
710
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
711
+ terminate_response = []
712
+ elif prompt_type == 'simple_instruct':
713
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
714
+ terminate_response = []
715
+ elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
716
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
717
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
718
+
719
+ PreInstruct = """
720
+ ### Instruction:
721
+ """
722
+
723
+ PreInput = """
724
+ ### Input:
725
+ """
726
+
727
+ PreResponse = """
728
+ ### Response:
729
+ """
730
+ if prompt_type in [7, "7", "instruct_with_end"]:
731
+ terminate_response = ['### End']
732
+ else:
733
+ terminate_response = None
734
+ elif prompt_type in [1, "1", "quality"]:
735
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
736
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
737
+
738
+ PreInstruct = """
739
+ ### Instruction:
740
+ """
741
+
742
+ PreInput = """
743
+ ### Input:
744
+ """
745
+
746
+ PreResponse = """
747
+ ### Response:
748
+ """
749
+ terminate_response = None
750
+ elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
751
+ if reduced or context or prompt_type in [2, "2", "human_bot"]:
752
+ preprompt = ''
753
+ else:
754
+ cur_date = time.strftime('%Y-%m-%d')
755
+ cur_time = time.strftime('%H:%M:%S %p %Z')
756
+
757
+ PRE_PROMPT = """\
758
+ Current Date: {}
759
+ Current Time: {}
760
+
761
+ """
762
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
763
+ start = human
764
+ promptB = promptA = '%s%s ' % (preprompt, start)
765
+
766
+ PreInstruct = ""
767
+
768
+ PreInput = None
769
+
770
+ PreResponse = bot
771
+
772
+ terminate_response = [start, PreResponse]
773
+ elif prompt_type in [3, "3", "dai_faq"]:
774
+ promptA = ''
775
+ promptB = 'Answer the following Driverless AI question.\n'
776
+
777
+ PreInstruct = """
778
+ ### Driverless AI frequently asked question:
779
+ """
780
+
781
+ PreInput = None
782
+
783
+ PreResponse = """
784
+ ### Driverless AI documentation answer:
785
+ """
786
+ terminate_response = ['\n\n']
787
+ elif prompt_type in [5, "5", "summarize"]:
788
+ promptA = promptB = PreInput = ''
789
+ PreInstruct = '## Main Text\n\n'
790
+ PreResponse = '\n\n## Summary\n\n'
791
+ terminate_response = None
792
+ elif prompt_type in [6, "6", "instruct_vicuna"]:
793
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
794
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
795
+
796
+ PreInstruct = """
797
+ ### Human:
798
+ """
799
+
800
+ PreInput = None
801
+
802
+ PreResponse = """
803
+ ### Assistant:
804
+ """
805
+ terminate_response = ['### Human:'] # only allow termination after the prompt is found correctly; otherwise cannot terminate
806
+ else:
807
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
808
+
809
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
810
+
811
+
812
+ def generate_prompt(data_point, prompt_type, chat, reduced):
813
+ context = data_point.get('context')
814
+ if context is None:
815
+ context = ''
816
+ instruction = data_point.get('instruction')
817
+ input = data_point.get('input')
818
+ output = data_point.get('output')
819
+ prompt_type = data_point.get('prompt_type', prompt_type)
820
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
821
+ promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
822
+
823
+ prompt = context
824
+
825
+ if input and promptA:
826
+ prompt += f"""{promptA}"""
827
+ elif promptB:
828
+ prompt += f"""{promptB}"""
829
+
830
+ if instruction and PreInstruct is not None and input and PreInput is not None:
831
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
832
+ prompt = inject_newline(prompt_type, prompt)
833
+ elif instruction and input and PreInstruct is None and PreInput is not None:
834
+ prompt += f"""{PreInput}{instruction}
835
+ {input}"""
836
+ prompt = inject_newline(prompt_type, prompt)
837
+ elif input and instruction and PreInput is None and PreInstruct is not None:
838
+ prompt += f"""{PreInstruct}{instruction}
839
+ {input}"""
840
+ prompt = inject_newline(prompt_type, prompt)
841
+ elif instruction and PreInstruct is not None:
842
+ prompt += f"""{PreInstruct}{instruction}"""
843
+ prompt = inject_newline(prompt_type, prompt)
844
+ elif input and PreInput is not None:
845
+ prompt += f"""{PreInput}{input}"""
846
+ prompt = inject_newline(prompt_type, prompt)
847
+ elif input and instruction and PreInput is not None:
848
+ prompt += f"""{PreInput}{instruction}{input}"""
849
+ prompt = inject_newline(prompt_type, prompt)
850
+ elif input and instruction and PreInstruct is not None:
851
+ prompt += f"""{PreInstruct}{instruction}{input}"""
852
+ prompt = inject_newline(prompt_type, prompt)
853
+ elif input and instruction:
854
+ # i.e. for simple_instruct
855
+ prompt += f"""{instruction}: {input}"""
856
+ prompt = inject_newline(prompt_type, prompt)
857
+ elif input:
858
+ prompt += f"""{input}"""
859
+ prompt = inject_newline(prompt_type, prompt)
860
+ elif instruction:
861
+ prompt += f"""{instruction}"""
862
+ prompt = inject_newline(prompt_type, prompt)
863
+
864
+ if PreResponse is not None:
865
+ prompt += f"""{PreResponse}"""
866
+ pre_response = PreResponse # Don't use strip
867
+ else:
868
+ pre_response = ''
869
+
870
+ if output:
871
+ prompt += f"""{output}"""
872
+
873
+ return prompt, pre_response, terminate_response
874
+
875
+
876
+ def inject_newline(prompt_type, prompt):
877
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
878
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
879
+ prompt += '\n'
880
+ return prompt
881
+
882
+
883
+ example_data_point0 = dict(instruction="Summarize",
884
+ input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
885
+ output="Ducks eat and swim at the lake.")
886
+
887
+ example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
888
+ output="Einstein.")
889
+
890
+ example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
891
+ output="Einstein.")
892
+
893
+ example_data_points = [example_data_point0, example_data_point1, example_data_point2]
894
+
895
+
896
+ def test_train_prompt(prompt_type='instruct', data_point=0):
897
+ example_data_point = example_data_points[data_point]
898
+ return generate_prompt(example_data_point, prompt_type, False, False)
899
+
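A hedged way to inspect the rendered training prompt via the helper above; the shape shown in the comments is approximate, reconstructed from get_prompt for the 'instruct' type.

    prompt, pre_response, terminate_response = test_train_prompt('instruct', 0)
    # prompt roughly looks like:
    #   Below is an instruction that describes a task, paired with an input ...
    #   ### Instruction:
    #   Summarize
    #   ### Input:
    #   Ducks eat seeds by the lake, ...
    #   ### Response:
    #   Ducks eat and swim at the lake.
    # pre_response is the "### Response:" block and terminate_response is None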
900
+
901
+ def test_debug():
902
+ fire.Fire(train)
903
+
904
+
905
+ if __name__ == "__main__":
906
+ CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
907
+ CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
908
+ log(f"""
909
+ Example runs on 4 GPUs:
910
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
911
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
912
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
913
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
914
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 --data_mix_in_path='' &> 13.log
915
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
916
+
917
+ All metrics:
918
+ CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
919
+
920
+ # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
921
+ rippa>
922
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
923
+ ova>
924
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
925
+ timemachine>
926
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
927
+
928
+ """, flush=True)
929
+
930
+ if os.environ.get("LOCAL_RANK") is None:
931
+ # then not using torchrun, so can't do distributed, ensure CVD set
932
+ assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
933
+
934
+ fire.Fire(train)
h2o-logo.svg ADDED
prompter.py ADDED
@@ -0,0 +1,106 @@
1
+ from finetune import generate_prompt
2
+
3
+
4
+ class Prompter(object):
5
+ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
6
+ allowed_repeat_line_length=10):
7
+ self.prompt_type = prompt_type
8
+ data_point = dict(instruction='', input='', output='')
9
+ _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
10
+ self.debug = debug
11
+ self.chat = chat
12
+ self.stream_output = stream_output
13
+ self.repeat_penalty = repeat_penalty
14
+ self.allowed_repeat_line_length = allowed_repeat_line_length
15
+
16
+ def generate_prompt(self, data_point):
17
+ reduced = False
18
+ prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
19
+ if self.debug:
20
+ print("prompt: ", prompt, flush=True)
21
+ self.prompt = prompt
22
+ return prompt
23
+
24
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
25
+ if isinstance(outputs, str):
26
+ outputs = [outputs]
27
+ if self.debug:
28
+ print("output: ", '\n\n'.join(outputs), flush=True)
29
+ if prompt is not None:
30
+ self.prompt = prompt
31
+
32
+ def clean_response(response):
33
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
34
+ for word in meaningless_words:
35
+ response = response.replace(word, "")
36
+ if sanitize_bot_response:
37
+ from better_profanity import profanity
38
+ response = profanity.censor(response)
39
+ response = response.strip("\n")
40
+ return response
41
+
42
+ def clean_repeats(response):
43
+ lines = response.split('\n')
44
+ new_lines = []
45
+ [new_lines.append(line) for line in lines if
46
+ line not in new_lines or len(line) < self.allowed_repeat_line_length]
47
+ if self.debug and len(lines) != len(new_lines):
48
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
49
+ response = '\n'.join(new_lines)
50
+ return response
51
+
52
+ multi_output = len(outputs) > 1
53
+
54
+ for oi, output in enumerate(outputs):
55
+ if self.prompt_type in [0, '0', 'plain']:
56
+ output = clean_response(output)
57
+ else:
58
+ # find the first instance of pre_response
59
+ # prompt sometimes has odd characters that mutate length,
60
+ # so can't go by length alone
61
+ if self.pre_response:
62
+ outputi = output.find(prompt)
63
+ if outputi >= 0:
64
+ output = output[outputi + len(prompt):]
65
+ allow_terminate = True
66
+ else:
67
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
68
+ output = output[len(prompt) - len(self.pre_response):]
69
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
70
+ if self.pre_response in output:
71
+ output = output.split(self.pre_response)[1]
72
+ allow_terminate = True
73
+ else:
74
+ print("Failure of parsing: %s" % output, flush=True)
75
+ allow_terminate = False
76
+ else:
77
+ allow_terminate = True
78
+ output = output[len(prompt):]
79
+ # clean after subtract prompt out, so correct removal of pre_response
80
+ output = clean_response(output).strip()
81
+ if self.repeat_penalty:
82
+ output = clean_repeats(output).strip()
83
+ if self.terminate_response and allow_terminate:
84
+ finds = []
85
+ for term in self.terminate_response:
86
+ finds.append(output.find(term))
87
+ finds = [x for x in finds if x >= 0]
88
+ if len(finds) > 0:
89
+ termi = finds[0]
90
+ output = output[:termi].strip()
91
+ else:
92
+ output = output.strip()
93
+ else:
94
+ output = output.strip()
95
+ if multi_output:
96
+ # prefix with output counter
97
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
98
+ if oi > 0:
99
+ # postfix outputs with separator
100
+ output += '\n'
101
+ outputs[oi] = output
102
+ # join all outputs, only one extra new line between outputs
103
+ output = '\n'.join(outputs)
104
+ if self.debug:
105
+ print("outputclean: ", '\n\n'.join(outputs), flush=True)
106
+ return output
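A brief, hedged usage sketch of the Prompter class above; the model continuation string is invented for illustration.

    prompter = Prompter('instruct', debug=False, chat=False)
    data_point = dict(instruction="Summarize", input="Ducks eat seeds by the lake.", output="")
    prompt = prompter.generate_prompt(data_point)
    raw_output = prompt + "Ducks eat at the lake."           # assumed model continuation
    print(prompter.get_response(raw_output, prompt=prompt))  # -> "Ducks eat at the lake."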
requirements.txt ADDED
@@ -0,0 +1,48 @@
1
+ # for generate (gradio server) and finetune
2
+ datasets==2.10.1
3
+ sentencepiece==0.1.97
4
+ accelerate==0.18.0
5
+ gradio==3.27.0
6
+ huggingface_hub==0.13.4
7
+ appdirs==1.4.4
8
+ fire==0.5.0
9
+ docutils==0.19
10
+ torch==2.0.0
11
+ evaluate==0.4.0
12
+ rouge_score==0.1.2
13
+ sacrebleu==2.3.1
14
+ scikit-learn==1.2.2
15
+ alt-profanity-check==1.2.2
16
+ better-profanity==0.6.1
17
+ numpy==1.24.2
18
+ pandas==1.5.3
19
+ matplotlib==3.7.1
20
+ loralib==0.1.1
21
+ bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
+ transformers==4.28.1
24
+ tokenizers==0.13.3
25
+
26
+ # optional for generate
27
+ pynvml==11.5.0
28
+ psutil==5.9.4
29
+
30
+ # optional for finetune
31
+ tensorboard==2.12.1
32
+ neptune==1.1.1
33
+
34
+ # for gradio client
35
+ gradio_client==0.1.3
36
+ beautifulsoup4==4.12.2
37
+ markdown==3.4.1
38
+
39
+ # data and testing
40
+ pytest==7.2.2
41
+ pytest-xdist==3.2.1
42
+ nltk==3.8.1
43
+ textstat==0.7.3
44
+ pandoc==2.3
45
+ pypandoc==1.11
46
+ openpyxl==3.1.2
47
+ lm_dataformat==0.0.20
48
+ bioc==2.0
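These pinned dependencies are installed in the usual way, e.g. pip install -r requirements.txt; per the section comments above, the "optional" groups can be skipped if only generation or only fine-tuning is needed.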
stopping.py ADDED
@@ -0,0 +1,139 @@
1
+ import traceback
2
+ from queue import Queue
3
+ from threading import Thread
4
+ import collections.abc
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+
9
+
10
+ class StoppingCriteriaSub(StoppingCriteria):
11
+
12
+ def __init__(self, stops=[], encounters=[]):
13
+ super().__init__()
14
+ assert len(stops) % len(encounters) == 0, "Number of stops must be a multiple of number of encounters"
15
+ self.encounters = encounters
16
+ self.stops = [stop.to("cuda") for stop in stops]
17
+ self.num_stops = [0] * len(stops)
18
+
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ for stopi, stop in enumerate(self.stops):
21
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
22
+ self.num_stops[stopi] += 1
23
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
24
+ return True
25
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
+ return False
28
+
29
+
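A hedged sketch of wiring StoppingCriteriaSub into generation; the stop string, tokenizer, and model are placeholders, and a CUDA device is assumed since the stops are moved to "cuda" in __init__.

    from transformers import StoppingCriteriaList
    # assumes `tokenizer` and `model` are already loaded as in finetune.py
    stop_ids = torch.tensor(tokenizer("<human>:", add_special_tokens=False)["input_ids"])
    stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=[stop_ids], encounters=[1])])
    # outputs = model.generate(**inputs, stopping_criteria=stopping, max_new_tokens=128)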
30
+ class Stream(StoppingCriteria):
31
+ """
32
+ This class can be used to callback during generation. Keep
33
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
34
+
35
+ Args:
36
+ func (`callable`):
37
+ A callable function to apply on first input in list every iteration of generation
38
+ """
39
+
40
+ def __init__(self, func=None):
41
+ self.func = func
42
+
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ if self.func is not None:
45
+ # only consume first of multiple responses
46
+ self.func(input_ids[0])
47
+ return False
48
+
49
+
50
+ class CallbackToGenerator(collections.abc.Generator):
51
+ """
52
+ A generator wrapper for a function that invokes a callback multiple times.
53
+
54
+ Calling `send` on the generator emits a value from one callback, and returns
55
+ the next.
56
+
57
+ Note this starts a background thread
58
+ """
59
+
60
+ def __init__(self, func, *args, callback=None, **kwargs):
61
+ self.func = func
62
+ self.args = args
63
+ self.kwargs = kwargs
64
+ self.callback = callback
65
+
66
+ self._ready_queue = Queue(1)
67
+ self._done_queue = Queue(1)
68
+ self._done_holder = [False]
69
+
70
+ # local to avoid reference cycles
71
+ ready_queue = self._ready_queue
72
+ done_queue = self._done_queue
73
+ done_holder = self._done_holder
74
+
75
+ def val_callback(value):
76
+ done_queue.put((False, value))
77
+ cmd, val = ready_queue.get()
78
+ if cmd == 'send':
79
+ return val
80
+ elif cmd == 'throw':
81
+ raise val
82
+ else:
83
+ assert False # pragma: no cover
84
+
85
+ def thread_func():
86
+ while True:
87
+ cmd, val = ready_queue.get()
88
+ if cmd == 'send' and val is not None:
89
+ done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
90
+ continue
91
+ break
92
+ try:
93
+ if cmd == 'throw':
94
+ raise val
95
+ ret = func(callback=val_callback, **self.kwargs)
96
+ raise StopIteration(ret) if ret is not None else StopIteration
97
+ except BaseException as e:
98
+ done_holder[0] = True
99
+ done_queue.put((True, e))
100
+
101
+ self._thread = Thread(target=thread_func)
102
+ self._thread.start()
103
+
104
+ def _put(self, *args):
105
+ if self._done_holder[0]:
106
+ raise StopIteration
107
+ self._ready_queue.put(args)
108
+ is_exception, val = self._done_queue.get()
109
+ if is_exception:
110
+ try:
111
+ raise val
112
+ finally:
113
+ # prevent val's traceback containing a reference cycle
114
+ del val
115
+ else:
116
+ return val
117
+
118
+ def send(self, value):
119
+ return self._put('send', value)
120
+
121
+ def throw(self, exc):
122
+ return self._put('throw', exc)
123
+
124
+ def close(self):
125
+ try:
126
+ self.throw(GeneratorExit)
127
+ except StopIteration:
128
+ self._thread.join()
129
+ except GeneratorExit:
130
+ self._thread.join()
131
+ except BaseException:
132
+ self._thread.join()
133
+ raise
134
+ else:
135
+ # yielded again, can't clean up the thread
136
+ raise RuntimeError('Task with callback ignored GeneratorExit')
137
+
138
+ def __del__(self):
139
+ self.close()
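A minimal hedged example of CallbackToGenerator with a toy function (everything below is invented for illustration); note that positional *args stored by __init__ are not forwarded to func in this version, only **kwargs.

    def produce(callback=None):
        for i in range(3):
            callback(i)        # each callback value is emitted by the generator
        return "done"

    gen = CallbackToGenerator(produce)
    for value in gen:
        print(value)           # prints 0, 1, 2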
utils.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import gc
3
+ import random
4
+ import time
5
+ import traceback
6
+ import zipfile
7
+ from datetime import datetime
8
+ import filelock
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+
13
+
14
+ def set_seed(seed: int):
15
+ """
16
+ Sets global random seeds so results are the same every time we run.
17
+ This is for REPRODUCIBILITY.
18
+ """
19
+ np.random.seed(seed)
20
+ random_state = np.random.RandomState(seed)
21
+ random.seed(seed)
22
+ torch.manual_seed(seed)
23
+ torch.cuda.manual_seed(seed)
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+ os.environ['PYTHONHASHSEED'] = str(seed)
27
+ return random_state
28
+
29
+
30
+ def flatten_list(lis):
31
+ """Given a list, possibly nested to any level, return it flattened."""
32
+ new_lis = []
33
+ for item in lis:
34
+ if type(item) == type([]):
35
+ new_lis.extend(flatten_list(item))
36
+ else:
37
+ new_lis.append(item)
38
+ return new_lis
39
+
40
+
41
+ def clear_torch_cache():
42
+ if torch.cuda.is_available():
43
+ torch.cuda.empty_cache()
44
+ torch.cuda.ipc_collect()
45
+ gc.collect()
46
+
47
+
48
+ def system_info():
49
+ import psutil
50
+
51
+ system = {}
52
+ # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
53
+ # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
54
+ temps = psutil.sensors_temperatures(fahrenheit=False)
55
+ if 'coretemp' in temps:
56
+ coretemp = temps['coretemp']
57
+ temp_dict = {k.label: k.current for k in coretemp}
58
+ for k, v in temp_dict.items():
59
+ system['CPU_C/%s' % k] = v
60
+
61
+ # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
62
+ from pynvml.smi import nvidia_smi
63
+ nvsmi = nvidia_smi.getInstance()
64
+
65
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
66
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
67
+ for k, v in gpu_power_dict.items():
68
+ system['GPU_W/%s' % k] = v
69
+
70
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
71
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
72
+ for k, v in gpu_temp_dict.items():
73
+ system['GPU_C/%s' % k] = v
74
+
75
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
76
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
77
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
78
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
79
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
80
+ for k, v in gpu_memory_frac_dict.items():
81
+ system[f'GPU_M/%s' % k] = v
82
+
83
+ return system
84
+
85
+
86
+ def system_info_print():
87
+ try:
88
+ df = pd.DataFrame.from_dict(system_info(), orient='index')
89
+ # avoid slamming GPUs
90
+ time.sleep(1)
91
+ return df.to_markdown()
92
+ except Exception as e:
93
+ return "Error: %s" % str(e)
94
+
95
+
96
+ def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
97
+ try:
98
+ return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
99
+ except Exception as e:
100
+ traceback.print_exc()
101
+ print('Exception in zipping: %s' % str(e))
102
+
103
+
104
+ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
105
+ if zip_file is None:
106
+ datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
107
+ host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
108
+ zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
109
+ assert root_dirs is not None
110
+
111
+ with zipfile.ZipFile(zip_file, "w") as expt_zip:
112
+ for root_dir in root_dirs:
113
+ if root_dir is None:
114
+ continue
115
+ for root, d, files in os.walk(root_dir):
116
+ for file in files:
117
+ file_to_archive = os.path.join(root, file)
118
+ assert os.path.exists(file_to_archive)
119
+ path_to_archive = os.path.relpath(file_to_archive, base_dir)
120
+ expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
121
+ return zip_file
122
+
123
+
124
+ def save_generate_output(output=None, base_model=None, save_dir=None):
125
+ try:
126
+ return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
127
+ except Exception as e:
128
+ traceback.print_exc()
129
+ print('Exception in saving: %s' % str(e))
130
+
131
+
132
+ def _save_generate_output(output=None, base_model=None, save_dir=None):
133
+ """
134
+ Save conversation to .json, row by row.
135
+ Rows are written to save_dir/history.json; save_dir is created if it does not exist.
136
+ Appends if the file already exists.
137
+ """
138
+ assert save_dir, "save_dir must be provided"
139
+ if os.path.exists(save_dir) and not os.path.isdir(save_dir):
140
+ raise RuntimeError("save_dir already exists and is not a directory!")
141
+ os.makedirs(save_dir, exist_ok=True)
142
+ import json
143
+ if output[-10:] == '\n\n<human>:':
144
+ # remove trailing <human>:
145
+ output = output[:-10]
146
+ with filelock.FileLock("save_dir.lock"):
147
+ # lock logging in case have concurrency
148
+ with open(os.path.join(save_dir, "history.json"), "a") as f:
149
+ # just add [ at start, and ] at end, and have proper JSON dataset
150
+ f.write(
151
+ " " + json.dumps(
152
+ dict(text=output, time=time.ctime(), base_model=base_model)
153
+ ) + ",\n"
154
+ )