joaogante HF Staff commited on
Commit
f84586a
·
1 Parent(s): e9837c5

allow passing a sink cache back in

Browse files
custom_generate/generate.py CHANGED
@@ -194,6 +194,17 @@ class SinkCache(Cache):
194
 
195
 
196
  def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
197
- past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
 
 
 
 
 
 
 
 
 
 
 
198
  generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
199
  return generation_outputs
 
194
 
195
 
196
  def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
197
+ # compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result in an
198
+ # infinite loop. This is solved in transformers 4.53.
199
+ kwargs.pop("custom_generate", None)
200
+
201
+ # prepare the cache, it is was not passed.
202
+ past_key_values = kwargs.pop("past_key_values", None)
203
+ if past_key_values is None:
204
+ past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
205
+ elif not isinstance(past_key_values, SinkCache):
206
+ raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
207
+
208
+ # generate with the cache
209
  generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
210
  return generation_outputs
custom_generate/requirements.txt DELETED
@@ -1 +0,0 @@
1
- transformers>=4.53.0 # 4.52 results in an infinite loop