File size: 981 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import unittest
from transformers.testing_utils import Expectations
class ExpectationsTest(unittest.TestCase):
def test_expectations(self):
expectations = Expectations(
{
(None, None): 1,
("cuda", 8): 2,
("cuda", 7): 3,
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
("xpu", 3): 7,
}
)
def check(value, key):
assert expectations.find_expectation(key) == value
# npu has no matches so should find default expectation
check(1, ("npu", None))
check(7, ("xpu", 3))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))
check(4, ("rocm", None))
check(2, ("cuda", 2))
expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))
|