File size: 981 Bytes
e0be88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import unittest

from transformers.testing_utils import Expectations


class ExpectationsTest(unittest.TestCase):
    """Tests for how `Expectations.find_expectation` resolves device keys."""

    def test_expectations(self):
        """Check exact matches, fallbacks to other entries, and the no-match error."""
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
                ("xpu", 3): 7,
            }
        )

        # (expected value, lookup key) pairs.
        cases = [
            # npu has no matches so should find default expectation
            (1, ("npu", None)),
            # Exact (device type, major version) matches.
            (7, ("xpu", 3)),
            (2, ("cuda", 8)),
            (3, ("cuda", 7)),
            # There is no ("rocm", 9) entry; the ("rocm", 8) value is expected.
            (4, ("rocm", 9)),
            # Expected to yield the ("rocm", 8) value (4), not the literal
            # ("rocm", None) entry's value (5).
            (4, ("rocm", None)),
            # There is no ("cuda", 2) entry; the ("cuda", 8) value is expected.
            (2, ("cuda", 2)),
        ]
        for expected, key in cases:
            # subTest reports every mismatching key instead of stopping at the
            # first one, and assertEqual (unlike bare assert) survives `-O`.
            with self.subTest(key=key):
                self.assertEqual(expectations.find_expectation(key), expected)

        # With no (None, None) default entry and no xpu entry, the lookup is
        # expected to raise. A separate name avoids rebinding `expectations`.
        cuda_only = Expectations({("cuda", 8): 1})
        with self.assertRaises(ValueError):
            cuda_only.find_expectation(("xpu", None))