File size: 46,369 Bytes
e0be88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

from parameterized import parameterized

from transformers import set_seed
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
from transformers.testing_utils import (
    CaptureStderr,
    backend_device_count,
    backend_torch_accelerator_module,
    cleanup,
    get_gpu_count,
    is_torch_available,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Cache,
        ClvpForCausalLM,
        DynamicCache,
        Gemma2Config,
        GenerationConfig,
        HybridCache,
        LlamaConfig,
        SlidingWindowCache,
        StaticCache,
        convert_and_export_with_cache,
        pipeline,
    )
    from transformers.integrations.executorch import export_with_dynamic_cache


# Cache implementations exercised by the parameterized tests in this file: every registered implementation except the
# ones excluded below.
TEST_CACHE_IMPLEMENTATIONS = [
    implementation_name
    for implementation_name in ALL_CACHE_IMPLEMENTATIONS
    # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
    # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
    if implementation_name not in ("mamba", "offloaded_hybrid")
]


@require_torch
class CacheTest(unittest.TestCase):
    """Cache tests that don't require loading models"""

    def test_dynamic_cache_retrocompatibility(self):
        """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
        num_layers = 10
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with `num_layers` layers in both formats
        for layer_idx in range(num_layers):
            new_key = torch.rand((2, 4, 8, 16))
            new_value = torch.rand((2, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Sanity check 1: they must have the same shapes.
        # NOTE: `assertEqual` is required here -- `assertTrue(a, b)` treats `b` as the failure message and only checks
        # that `a` is truthy, so it would never actually compare the two lengths.
        self.assertEqual(len(legacy_cache), len(new_cache))
        for layer_idx in range(num_layers):
            # Compare the legacy layer entry against the corresponding DynamicCache layer entry
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
            for key_value_idx in range(2):
                self.assertEqual(
                    legacy_cache[layer_idx][key_value_idx].shape, new_cache[layer_idx][key_value_idx].shape
                )

        # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they all return the
        # expected value (dim -2 of the (2, 4, 8, 16) tensors created above)
        self.assertEqual(legacy_cache[0][0].shape[-2], 8)
        self.assertEqual(new_cache[0][0].shape[-2], 8)
        self.assertEqual(new_cache.get_seq_length(), 8)

        # Sanity check 3: they must be equal, and both support indexing
        for layer_idx in range(num_layers):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 1: We can convert from legacy to new with no changes
        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
        for layer_idx in range(num_layers):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 2: We can convert from new to legacy with no changes
        to_legacy = new_cache.to_legacy_cache()
        for layer_idx in range(num_layers):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
                )

    def test_reorder_cache_retrocompatibility(self):
        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
        legacy_reorder_fn = ClvpForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function

        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((4, 4, 8, 16))
            new_value = torch.rand((4, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
        # and batch_size=1
        beam_idx = torch.randint(low=0, high=4, size=(4,))

        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
        new_cache.reorder_cache(beam_idx)

        # Let's check that the results are the same
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
                    )
                )

    def test_static_cache_mha_mqa_gqa(self):
        """
        Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
        attention (MQA)
        """

        def _random_kvs(config):
            # shape for key and values: (batch_size, num_heads, seq_len, head_dim)
            random_keys = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            random_values = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            return random_keys, random_values

        # attention variant -> (config kwargs, expected number of key/value heads stored in the cache)
        test_cases = {
            "mha": ({"num_attention_heads": 32}, 32),
            "gqa": ({"num_attention_heads": 32, "num_key_value_heads": 4}, 4),
            "mqa": ({"num_attention_heads": 32, "num_key_value_heads": 1}, 1),
        }
        for variant, (config_kwargs, expected_kv_heads) in test_cases.items():
            with self.subTest(variant):
                config = LlamaConfig(**config_kwargs)
                static_cache = StaticCache(config=config, max_batch_size=1, max_cache_len=10, device=torch_device)
                cached_keys, cached_values = static_cache.update(
                    *_random_kvs(config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
                )
                # The cache is pre-allocated to `max_cache_len`, hence the sequence dimension of 10
                self.assertEqual(cached_keys.shape, (1, expected_kv_heads, 10, 128))
                self.assertEqual(cached_values.shape, (1, expected_kv_heads, 10, 128))


def _skip_on_failed_cache_prerequisites(test, cache_implementation):
    """
    Skip `test` when the environment cannot run `cache_implementation`.

    Args:
        test: the running `unittest.TestCase`; its `skipTest` is called with the reason for each unmet prerequisite.
        cache_implementation (str): name of the cache implementation under test.
    """
    # Optional dependency: the quantized cache needs optimum-quanto
    if cache_implementation == "quantized" and not is_optimum_quanto_available():
        test.skipTest("Quanto is not available")

    # Remaining checks only concern offloaded caches
    if "offloaded" not in cache_implementation:
        return

    if torch_device is None or torch_device == "cpu":
        test.skipTest("Offloaded caches require an accelerator")

    requires_single_device = cache_implementation in ("offloaded_static", "offloaded_hybrid_chunked")
    if requires_single_device and backend_device_count(torch_device) != 1:
        test.skipTest("Offloaded static caches require exactly 1 accelerator")


@require_torch
class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

    @classmethod
    def setUpClass(cls):
        # Load once and reuse across tests
        cls.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", padding_side="left")
        cls.model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct", device_map="auto", torch_dtype=torch.float16
        )
        cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows

    @classmethod
    def tearDownClass(cls):
        # The shared model/tokenizer are class attributes and would otherwise stay referenced for the rest of the
        # session; drop them and free memory, as `CacheHardIntegrationTest` does after its tests.
        del cls.model
        del cls.tokenizer
        cleanup(torch_device, gc_collect=True)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_batched(self, cache_implementation):
        """Sanity check: caches' `.update` function expects batched inputs"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

        inputs = self.tokenizer(
            ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)

        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            return_dict_in_generate=True,
            cache_implementation=cache_implementation,
            disable_compile=True,
        )
        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)
        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_beam_search(self, cache_implementation):
        """
        Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
        (an output sequence contains multiple beam indices).
        """
        _skip_on_failed_cache_prerequisites(self, cache_implementation)
        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
            # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
            # different from the other caches.
            self.skipTest("`offloaded_hybrid_chunked` fails this test")

        EXPECTED_GENERATION = [
            "Blue is the color of the sky, and the color of",
            "Blue is the color of the sky, and the second is",
        ]

        inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device)
        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            num_beams=2,
            num_return_sequences=2,
            cache_implementation=cache_implementation,
            disable_compile=True,
            return_dict_in_generate=True,
        )
        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)
        # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around
        self.assertTrue(any(len(set(beams_in_sequence)) > 1 for beams_in_sequence in gen_out.beam_indices))
        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_extra_left_padding(self, cache_implementation):
        """Tests that adding extra left-padding does not affect the generation with the cache"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]

        inputs = self.tokenizer(["The cat"], padding=True, return_tensors="pt").to(self.model.device)
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": 10,
            "cache_implementation": cache_implementation,
            "disable_compile": True,
        }

        gen_out = self.model.generate(**inputs, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

        # Now with extra left-padding (the tokenizer pads on the left, see `setUpClass`)
        inputs_expanded = self.tokenizer(["The cat"], padding=True, return_tensors="pt", pad_to_multiple_of=32)
        inputs_expanded = inputs_expanded.to(self.model.device)
        self.assertLess(inputs.input_ids.shape[1], inputs_expanded.input_ids.shape[1])
        gen_out = self.model.generate(**inputs_expanded, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)


@require_torch_accelerator
class CacheHardIntegrationTest(unittest.TestCase):
    """Hard cache integration tests that require loading different models"""

    def setUp(self):
        # Clears memory before each test. Some tests use large models, which might result in suboptimal torch
        # re-allocation if we run multiple tests in a row without clearing memory.
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def tearDownClass(cls):
        # Clears memory after the last test. See `setUp` for more details.
        cleanup(torch_device, gc_collect=True)

    @slow
    def test_dynamic_cache_hard(self):
        """Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
        inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)

        set_seed(0)
        gen_out = model.generate(
            **inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
        )
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        # Sum of the scores of the generated tokens: `gen_out.scores[idx]` is paired with the token at position
        # `input_length + idx` of the first sequence. Converted to a Python float so `assertAlmostEqual` compares
        # plain numbers rather than a 0-d tensor.
        input_length = inputs.input_ids.shape[1]
        score_sum = float(
            sum(score[0][gen_out.sequences[0][input_length + idx]] for idx, score in enumerate(gen_out.scores))
        )

        EXPECTED_GENERATION = (
            "Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have "
            "a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is "
            "covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom "
            "themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, "
            "some are not. They have a lot of different personalities. They can be very independent, or they can be "
            "very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or "
            "they can be very silly. They have a lot of different behaviors. They can be very curious, or they can "
            "be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they "
            "can be very solitary. They can be very active, or they can be very inactive. They can be very "
            "affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can "
            "be very intelligent, or they can be very silly. They have a lot of different behaviors. They can be "
            "very curious, or they can"
        )
        EXPECTED_SCORE_SUM = 11017.4971
        self.assertEqual(decoded[0], EXPECTED_GENERATION)
        self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2)
        self.assertIsInstance(gen_out.past_key_values, DynamicCache)  # sanity check

    @parameterized.expand(["eager", "sdpa"])
    @require_torch_accelerator
    @slow
    def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
        """Tests that different cache implementations work well with eager and SDPA inference"""
        EXPECTED_GENERATION = [
            "The best color is the one that is most suitable for the purpose.",
            "We should not undermind the issues at hand, but instead, we should focus on the things",
        ]

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-4B",
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
            device_map="auto",
        )
        inputs = tokenizer(
            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
        ).to(model.device)
        generation_kwargs = {"do_sample": False, "max_new_tokens": 10, "return_dict_in_generate": True}

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs)
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, dynamic"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, DynamicCache)  # sanity check

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static", disable_compile=True)
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, static, eager"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, StaticCache)  # sanity check

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static")
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, static, compiled"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, StaticCache)  # sanity check

    @require_torch_accelerator
    @slow
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device

        if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")

        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)

        torch_accelerator_module = backend_torch_accelerator_module(device.type)

        # Measure the peak allocated memory of each generation run independently
        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
        self.assertLess(offloaded_peak_memory, original_peak_memory)

    @require_torch_accelerator
    @slow
    def test_cache_copy(self):
        """Tests that we can manually set a cache, copy, and reuse it for generation"""
        # TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the
        # lazy init of cache layers
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device, torch_dtype=torch.bfloat16)

        prompt_cache = StaticCache(
            config=model.config, max_batch_size=1, max_cache_len=1024, device=torch_device, dtype=torch.bfloat16
        )

        INITIAL_PROMPT = "You are a helpful assistant. "
        inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(torch_device)
        # This is the common prompt cached, we need to run forward without grad to be able to copy
        with torch.no_grad():
            prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

        prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
        responses = []
        for prompt in prompts:
            new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(torch_device)
            past_key_values = copy.deepcopy(prompt_cache)
            outputs = model.generate(
                **new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True
            )
            response = tokenizer.batch_decode(outputs)[0]
            responses.append(response)

        EXPECTED_DECODED_TEXT = [
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an "
            "enriching experience that broadens our horizons and allows us to explore the world beyond our comfort "
            "zones. Whether it's a short weekend getaway",
            "You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
            "of France.\n\n\n\n\n\n\n<|endoftext|>",
        ]

        self.assertEqual(responses, EXPECTED_DECODED_TEXT)

    @require_torch_multi_gpu
    def test_data_parallel_dynamic_cache(self):
        """
        Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
        multiple `DynamicCache` in the gather step.
        """

        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        # w/o DP: batch_size = num_gpu
        # w DP: batch_size = 1 (with num_gpus replicas)
        num_gpus = get_gpu_count()
        model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)

        # w/o DP
        no_parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(no_parallelism_cache, DynamicCache)

        # w DP
        model = torch.nn.DataParallel(model)
        parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(parallelism_cache, DynamicCache)

        # Check that the caches are the same
        for layer_idx in range(len(no_parallelism_cache)):
            for kv_idx in range(2):  # 0 = key, 1 = value
                torch.testing.assert_close(
                    actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
                )

    @require_torch_gpu
    def test_static_cache_no_cuda_graph_skips(self):
        """
        Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.

        (? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
        messages are being thrown to stderr?)
        """
        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)
        inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)

        # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
        with CaptureStderr() as cap:
            model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
        self.assertNotIn("cuda", cap.err.lower())

    @require_torch_multi_accelerator
    @slow
    @require_read_token
    def test_static_cache_multi_accelerator(self):
        """Regression test for #35164: static cache with multi-accelerator"""

        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Split the model across two accelerators: first half of the layers (and the embeddings) on 0, rest on 1
        device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
        num_hidden_layers = 26
        for i in range(num_hidden_layers):
            device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="bfloat16",
            device_map=device_map,
        )
        # `model.device` resolves to the device of the model's first parameter (the embeddings, placed on
        # accelerator 0 above). Unlike a bare `.to(0)`, this does not depend on how torch maps an integer index to
        # a device type, so it also works on non-CUDA accelerators.
        inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(model.device)
        _ = model(**inputs)
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")

    # NOTE: `@parameterized.expand` must be the outermost decorator. It injects the generated tests directly into the
    # class namespace, so decorators stacked above it never reach them (here, the CUDA requirement would be ignored).
    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    @require_torch_gpu
    def test_cache_gptj_model(self, cache_implementation):
        """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"
        pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)
        pipe.model.config.sliding_window = (
            256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None
        )
        out = pipe(
            "hello world",
            cache_implementation=cache_implementation,
            max_new_tokens=10,
            do_sample=False,
            disable_compile=True,
            return_tensors=True,
        )[0]["generated_token_ids"][-10:]
        EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
        self.assertListEqual(out, EXPECTED_OUTPUT)


@require_torch
class CacheExportIntegrationTest(unittest.TestCase):
    """Cache tests that rely on `torch.export()` and model loading"""

    def _assert_cache_buffers(self, exported_program, batch_size, max_cache_len, num_layers):
        """Assert shape and count of the key/value cache buffers registered on an exported program.

        Every buffer whose name starts with ``key_cache`` or ``value_cache`` must have ``batch_size`` on dim 0 and
        ``max_cache_len`` on dim 2. There must be exactly ``num_layers`` buffers of each kind.
        """
        n_key_caches = n_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                n_key_caches += 1
            elif buffer_name.startswith("value_cache"):
                n_value_caches += 1
            else:
                continue  # not a KV-cache buffer, nothing to check
            self.assertEqual(buffer.shape[0], batch_size, f"{buffer_name}: unexpected size of dim 0")
            self.assertEqual(buffer.shape[2], max_cache_len, f"{buffer_name}: unexpected size of dim 2")
        self.assertEqual(n_key_caches, num_layers)
        self.assertEqual(n_value_caches, num_layers)

    def test_dynamic_cache_exportability(self):
        """
        Tests that `DynamicCache` works with `torch.export()` and that the exported program matches eager mode
        """
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        model = model.eval()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        prompt = "What is the best way to debug python script?"
        inputs = tokenizer(prompt, return_tensors="pt")
        attention_mask = inputs.attention_mask
        input_ids = inputs.input_ids

        ep = export_with_dynamic_cache(model, input_ids, attention_mask)
        res = ep.module()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=DynamicCache(),
            use_cache=True,
        )
        num_layers = model.config.num_hidden_layers
        self.assertEqual(len(res.past_key_values.key_cache), num_layers)
        # Expected output-spec count: 2 per layer plus 1 (presumably key + value per layer, plus the logits).
        self.assertEqual(2 * num_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
            len(
                [
                    x
                    for x in ep.graph_signature.input_specs
                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
                ]
            ),
        )

        past_key_values_eager = DynamicCache()
        res_eager = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values_eager,
            use_cache=True,
        )
        self.assertTrue(torch.allclose(res.logits, res_eager.logits))

        exported_cache = res.past_key_values
        eager_cache = res_eager.past_key_values
        # zip() silently stops at the shorter sequence, so pin the lengths before comparing element-wise.
        self.assertEqual(len(eager_cache.key_cache), len(exported_cache.key_cache))
        self.assertEqual(len(eager_cache.value_cache), len(exported_cache.value_cache))
        for k1, k2 in zip(exported_cache.key_cache, eager_cache.key_cache):
            self.assertTrue(torch.allclose(k1, k2))

        for v1, v2 in zip(exported_cache.value_cache, eager_cache.value_cache):
            self.assertTrue(torch.allclose(v1, v2))

    def test_static_cache_exportability(self):
        """
        Tests that static cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        set_seed(0)
        device = "cpu"
        dtype = "bfloat16"
        cache_implementation = "static"
        attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
        batch_size = 1
        max_cache_len = 1234
        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_cache_len,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_cache_len,
                    "device": device,
                },
            ),
        )
        # Check if cache config is passed through correctly
        self.assertEqual(model.generation_config.use_cache, True)
        self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
        self.assertEqual(model.generation_config.max_length, max_cache_len)
        self.assertIsNotNone(model.generation_config.cache_config)
        self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
        self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)

        exported_program = convert_and_export_with_cache(model)

        # Check if the exported model is configured with the `StaticCache` correctly
        self._assert_cache_buffers(exported_program, batch_size, max_cache_len, model.config.num_hidden_layers)

        # Export with dynamic shapes using Dim.AUTO; the check is that export completes without raising.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
        exported_program = convert_and_export_with_cache(
            model,
            example_input_ids=input_ids,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )

    def test_hybrid_cache_exportability(self):
        """
        Tests that hybrid cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.6"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        set_seed(0)
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model.eval()
        self.assertEqual(model.config.use_cache, True)
        self.assertEqual(model.config.cache_implementation, "hybrid")

        # Export + HybridCache
        max_batch_size = 1
        max_cache_len = 23
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
        exported_program = exportable_module.export()
        self._assert_cache_buffers(exported_program, max_batch_size, max_cache_len, model.config.num_hidden_layers)

        # Export with dynamic shapes using Dim.AUTO; the check is that export completes without raising.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
        exported_program = exportable_module.export(
            input_ids=input_ids,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )


class SyntheticCacheTest(unittest.TestCase):
    """Tests cache behavior with simple dummy data."""

    def setUp(self):
        """Build the minimal single-layer Gemma2 config shared by every test in this class.

        Only `window_size`, `max_cache_len` and `config` are set here; each test constructs its own cache
        instances from them.
        """
        self.window_size = 4
        self.max_cache_len = 4
        self.config = Gemma2Config(
            num_hidden_layers=1,
            num_key_value_heads=1,
            num_attention_heads=1,
            head_dim=1,
            hidden_size=1,
            sliding_window=self.window_size,
            # With pattern 2, layer 0 is a sliding layer for HybridCache (1 % 2 != 0); tests that need a static
            # layer copy the config and override this value.
            sliding_window_pattern=2,
        )

    def test_static_cache_out_of_bounds(self):
        """A StaticCache write at position `max_cache_len` (one past the last slot) must raise IndexError."""
        cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        # Valid positions are 0 .. max_cache_len - 1, so this one is just past the end.
        past_the_end = torch.tensor([self.max_cache_len])
        write_kwargs = {"cache_position": past_the_end}

        with self.assertRaises(IndexError):
            cache.update(
                key_states=torch.tensor([[[[1.0]]]]),
                value_states=torch.tensor([[[[1.0]]]]),
                layer_idx=0,
                cache_kwargs=write_kwargs,
            )

    def test_static_cache(self):
        """Check that StaticCache stores each token at its explicit cache position.

        Scenario 1 - fill up to near capacity:
            prefill                  -> [1.0, 2.0, 0.0, 0.0]
            write 3.0 at position 2  -> [1.0, 2.0, 3.0, 0.0]
        Scenario 2 - fill to capacity (same cache):
            write 4.0 at position 3  -> [1.0, 2.0, 3.0, 4.0]
        """

        def single_token(value):
            # 4-D state of shape (1, 1, 1, 1) holding `value`
            return torch.tensor(value)[None, None, None, None]

        def write_token(cache, position, value):
            cache.update(
                key_states=single_token(value),
                value_states=single_token(value),
                layer_idx=0,
                cache_kwargs={"cache_position": torch.tensor([position])},
            )

        def layer0_keys(cache):
            return cache.key_cache[0][0, 0, :, 0].tolist()

        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)

        # Scenario 1: the prefill passes no cache_kwargs at all, then one token is written at position 2
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
        write_token(static_cache, 2, 3.0)
        self.assertEqual(layer0_keys(static_cache), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed")

        # Scenario 2: the last free slot is filled
        write_token(static_cache, 3, 4.0)
        self.assertEqual(layer0_keys(static_cache), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed")

    def test_sliding_window_cache(self):
        """Check SlidingWindowCache contents after prefills and single-token writes.

        Scenario 1 - write inside the window, nothing slides:
            prefill [1.0, 2.0, 0.0, 0.0], write 3.0 at position 2 -> [1.0, 2.0, 3.0, 0.0]
        Scenario 2 - write past the window end, contents shift left by one (pos > window_size - 1):
            prefill [1.0, 2.0, 3.0, 4.0], write 5.0 at position 4 -> [2.0, 3.0, 4.0, 5.0]
        Scenario 3 - prompt longer than the window keeps only the last window_size tokens:
            prefill [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] -> [3.0, 4.0, 5.0, 6.0]
        """
        window_kwargs = {"sliding_window": self.window_size}

        def fresh_cache():
            return SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)

        def prefill(cache, values):
            # Prefill positions 0 .. len(values) - 1 in a single update call.
            states = torch.tensor(values)[None, None, :, None]
            cache.update(
                key_states=states,
                value_states=states,
                layer_idx=0,
                cache_kwargs={"cache_position": torch.arange(len(values)), **window_kwargs},
            )

        def write_token(cache, position, value):
            cache.update(
                key_states=torch.tensor(value)[None, None, None, None],
                value_states=torch.tensor(value)[None, None, None, None],
                layer_idx=0,
                cache_kwargs={"cache_position": torch.tensor([position]), **window_kwargs},
            )

        def layer0_keys(cache):
            return cache.key_cache[0][0, 0, :, 0].tolist()

        # Scenario 1: Update within window, no slide yet
        cache = fresh_cache()
        prefill(cache, [1.0, 2.0, 0.0, 0.0])
        write_token(cache, 2, 3.0)
        self.assertEqual(layer0_keys(cache), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed")

        # Scenario 2: Update causing slide
        cache = fresh_cache()
        prefill(cache, [1.0, 2.0, 3.0, 4.0])
        write_token(cache, 4, 5.0)
        self.assertEqual(layer0_keys(cache), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed")

        # Scenario 3: Long prompt handling
        cache = fresh_cache()
        prefill(cache, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        self.assertEqual(layer0_keys(cache), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed")

    def test_hybrid_cache_static_mode(self):
        """Check HybridCache on a layer configured as static: tokens land at their positions, nothing slides.

        Scenario 1 - partial fill:
            prefill [1.0, 2.0, 0.0, 0.0], write 3.0 at position 2 -> [1.0, 2.0, 3.0, 0.0]
        Scenario 2 - fill to capacity (same cache):
            write 4.0 at position 3 -> [1.0, 2.0, 3.0, 4.0]
        """
        config = copy.deepcopy(self.config)
        config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)

        def write_token(cache, position, value):
            cache.update(
                key_states=torch.tensor(value)[None, None, None, None],
                value_states=torch.tensor(value)[None, None, None, None],
                layer_idx=0,
                cache_kwargs={"cache_position": torch.tensor([position])},
            )

        def layer0_keys(cache):
            return cache.key_cache[0][0, 0, :, 0].tolist()

        # Scenario 1
        cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4)},
        )
        write_token(cache, 2, 3.0)
        self.assertEqual(layer0_keys(cache), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed")

        # Scenario 2
        write_token(cache, 3, 4.0)
        self.assertEqual(layer0_keys(cache), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed")

    def test_hybrid_cache_sliding_mode(self):
        """Check HybridCache on a sliding-window layer (layer 0 under the shared config).

        Scenario 1 - write inside the window, nothing slides:
            prefill [1.0, 2.0, 0.0, 0.0], write 3.0 at position 2 -> [1.0, 2.0, 3.0, 0.0]
        Scenario 2 - first write past the window end shifts contents left (pos > window_size - 1):
            prefill [1.0, 2.0, 3.0, 4.0], write 5.0 at position 4 -> [2.0, 3.0, 4.0, 5.0]
        Scenario 3 - a further write on the same cache keeps shifting:
            write 6.0 at position 5 -> [3.0, 4.0, 5.0, 6.0]
        Scenario 4 - prompt longer than the window keeps only the last window_size tokens:
            prefill [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] -> [3.0, 4.0, 5.0, 6.0]
        """
        window = {"sliding_window": self.window_size}

        def new_cache():
            return HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)

        def prefill(cache, values):
            # Prefill positions 0 .. len(values) - 1 in a single update call.
            states = torch.tensor(values)[None, None, :, None]
            cache.update(
                key_states=states,
                value_states=states,
                layer_idx=0,
                cache_kwargs={"cache_position": torch.arange(len(values)), **window},
            )

        def append(cache, position, value):
            cache.update(
                key_states=torch.tensor(value)[None, None, None, None],
                value_states=torch.tensor(value)[None, None, None, None],
                layer_idx=0,
                cache_kwargs={"cache_position": torch.tensor([position]), **window},
            )

        def keys_of(cache):
            return cache.key_cache[0][0, 0, :, 0].tolist()

        # Scenario 1: Update within window, no slide yet
        hybrid_cache = new_cache()
        prefill(hybrid_cache, [1.0, 2.0, 0.0, 0.0])
        append(hybrid_cache, 2, 3.0)
        self.assertEqual(keys_of(hybrid_cache), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed")

        # Scenario 2: Update causing first slide
        hybrid_cache = new_cache()
        prefill(hybrid_cache, [1.0, 2.0, 3.0, 4.0])
        append(hybrid_cache, 4, 5.0)
        self.assertEqual(keys_of(hybrid_cache), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed")

        # Scenario 3: Update causing subsequent slide (continues on the Scenario 2 cache)
        append(hybrid_cache, 5, 6.0)
        self.assertEqual(keys_of(hybrid_cache), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed")

        # Scenario 4: Long prompt handling
        hybrid_cache = new_cache()
        prefill(hybrid_cache, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        self.assertEqual(keys_of(hybrid_cache), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed")