add sink cache options
Browse files- README.md +5 -0
- custom_generate/generate.py +2 -2
- custom_generate/requirements.txt +1 -0
README.md
CHANGED
@@ -17,12 +17,17 @@ This implementation should match the `SinkCache` class present in `transformers<
|
|
17 |
|
18 |
|
19 |
## Model compatibility
|
|
|
20 |
|
21 |
|
22 |
## Additional Arguments
|
|
|
|
|
23 |
|
24 |
|
25 |
## Output Type changes
|
|
|
|
|
26 |
|
27 |
|
28 |
## Example usage
|
|
|
17 |
|
18 |
|
19 |
## Model compatibility
|
20 |
+
- Decoder-only models
|
21 |
|
22 |
|
23 |
## Additional Arguments
|
24 |
+
- `window_length` (`int`, defaults to `256`): The length of the context window.
|
25 |
+
- `num_sink_tokens` (`int`, defaults to `4`): The number of sink tokens. See the original paper for more information.
|
26 |
|
27 |
|
28 |
## Output Type changes
|
29 |
+
- When `return_dict_in_generate=True`, `output.past_key_values` will be a `SinkCache` instance. `SinkCache` is defined
|
30 |
+
in `generate.py`, in this repository.
|
31 |
|
32 |
|
33 |
## Example usage
|
custom_generate/generate.py
CHANGED
@@ -193,7 +193,7 @@ class SinkCache(Cache):
|
|
193 |
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
194 |
|
195 |
|
196 |
-
def generate(model, **kwargs):
|
197 |
-
past_key_values = SinkCache(window_length=
|
198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
199 |
return generation_outputs
|
|
|
193 |
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
194 |
|
195 |
|
196 |
+
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
197 |
+
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
199 |
return generation_outputs
|
custom_generate/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
transformers>=4.53.0 # 4.52 results in an infinite loop
|