import time
from argparse import ArgumentParser

import jax
import numpy as np

from transformers import BertConfig, FlaxBertModel
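
# Select the compute dtype from the --precision flag (float32 by default).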
parser = ArgumentParser()
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
args = parser.parse_args()

dtype = jax.numpy.float32
if args.precision == "bfloat16":
    dtype = jax.numpy.bfloat16
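
# Benchmark configuration: BERT-base vocabulary size, batch size, and sequence length.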
VOCAB_SIZE = 30522
BS = 32
SEQ_LEN = 128
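

# Build one random batch of BERT inputs: random token ids with all-ones token type ids and attention mask.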
def get_input_data(batch_size=1, seq_length=384):
    shape = (batch_size, seq_length)
    input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
    token_type_ids = np.ones(shape).astype(np.int32)
    attention_mask = np.ones(shape).astype(np.int32)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
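

# Load bert-base-uncased in the requested dtype, overriding the activation to "gelu_new".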
inputs = get_input_data(BS, SEQ_LEN)
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)
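

# JIT-compile the forward pass; `model` and `inputs` are closed over, so the fixed batch is baked into the compiled function.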
@jax.jit
def func():
    outputs = model(**inputs)
    return outputs


(nwarmup, nbenchmark) = (5, 100)
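
# Warm-up runs: the first call triggers tracing and compilation; block so all warm-up work finishes before timing starts.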
for _ in range(nwarmup):
    jax.block_until_ready(func())
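
# JAX dispatches asynchronously, so block on each result to time execution rather than dispatch.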
start = time.time()
for _ in range(nbenchmark):
    jax.block_until_ready(func())
end = time.time()

print(f"Total time: {end - start:.3f} sec")
print(f"Throughput: {((nbenchmark * BS) / (end - start)):.3f} examples/sec")