joaogante HF Staff committed on
Commit e9837c5 · 1 Parent(s): 7f36015

add sink cache options

README.md CHANGED
@@ -17,12 +17,17 @@ This implementation should match the `SinkCache` class present in `transformers<
  
  
  ## Model compatibility
+ - Decoder-only models
  
  
  ## Additional Arguments
+ - `window_length` (`int`, defaults to `256`): The length of the context window.
+ - `num_sink_tokens` (`int`, defaults to `4`): The number of sink tokens. See the original paper for more information.
  
  
  ## Output Type changes
+ - When `return_dict_in_generate=True`, `output.past_key_values` will be a `SinkCache` instance. `SinkCache` is defined
+ in `generate.py`, in this repository.
  
  
  ## Example usage
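
For intuition about the two arguments documented in this README hunk, here is a minimal sketch of which token positions a sink cache would keep. It assumes the usual policy: the first `num_sink_tokens` positions are always kept, and the rest of the `window_length` budget holds the most recent tokens. The helper `retained_positions` is hypothetical and is not part of this repository; the real `SinkCache` works on per-layer key/value tensors rather than on position lists.

```python
from __future__ import annotations


def retained_positions(
    seq_len: int, window_length: int = 256, num_sink_tokens: int = 4
) -> list[int]:
    """Illustrative only: positions kept after `seq_len` tokens have been seen."""
    if num_sink_tokens > window_length:
        raise ValueError("num_sink_tokens must not exceed window_length")
    # Nothing is evicted until the window is full.
    if seq_len <= window_length:
        return list(range(seq_len))
    # Sink tokens: the first few positions, never evicted.
    sinks = list(range(num_sink_tokens))
    # The remaining budget is a sliding window over the most recent tokens.
    num_recent = window_length - num_sink_tokens
    recent = list(range(seq_len - num_recent, seq_len))
    return sinks + recent


print(retained_positions(seq_len=12, window_length=8, num_sink_tokens=2))
```

With the defaults from this commit (`window_length=256`, `num_sink_tokens=4`), this sketch would keep 4 sink positions plus the 252 most recent ones once the sequence grows past 256 tokens.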
custom_generate/generate.py CHANGED
@@ -193,7 +193,7 @@ class SinkCache(Cache):
          return self.key_cache[layer_idx], self.value_cache[layer_idx]
  
  
- def generate(model, **kwargs):
-     past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+ def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
+     past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
      generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
      return generation_outputs
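
The effect of the new signature can be sketched without loading a model: the two named parameters are consumed by the wrapper, and everything else is forwarded through `**kwargs`. `FakeModel` and `FakeCache` below are hypothetical stand-ins used only for illustration; they are not classes from this repository or from `transformers`.

```python
class FakeCache:
    """Hypothetical stand-in for SinkCache; only stores its two options."""

    def __init__(self, window_length, num_sink_tokens):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens


class FakeModel:
    """Hypothetical stand-in for a model; echoes the kwargs it receives."""

    def generate(self, **kwargs):
        return kwargs


def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
    # Same shape as the function in this diff, with the stand-in cache.
    past_key_values = FakeCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
    return model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)


received = generate(FakeModel(), window_length=128, max_new_tokens=20)
print(received["past_key_values"].window_length)
print(sorted(received))
```

In this sketch `window_length` never reaches the inner `generate` call; it is used only to build the cache. One consequence of the pattern is that a caller who also passes `past_key_values` or `use_cache` gets a `TypeError` for a duplicate keyword argument.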
custom_generate/requirements.txt ADDED
@@ -0,0 +1 @@
+ transformers>=4.53.0 # 4.52 results in an infinite loop