allow passing a sink cache back in
Browse files
custom_generate/generate.py
CHANGED
@@ -194,6 +194,17 @@ class SinkCache(Cache):
|
|
194 |
|
195 |
|
196 |
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
199 |
return generation_outputs
|
|
|
194 |
|
195 |
|
196 |
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
197 |
+
# compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result in an
|
198 |
+
# infinite loop. This is solved in transformers 4.53.
|
199 |
+
kwargs.pop("custom_generate", None)
|
200 |
+
|
201 |
+
# prepare the cache, if it was not passed.
|
202 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
203 |
+
if past_key_values is None:
|
204 |
+
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
205 |
+
elif not isinstance(past_key_values, SinkCache):
|
206 |
+
raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
|
207 |
+
|
208 |
+
# generate with the cache
|
209 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
210 |
return generation_outputs
|
custom_generate/requirements.txt
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
transformers>=4.53.0 # 4.52 results in an infinite loop
|
|
|
|