File size: 412 Bytes
c5d43bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
#!/usr/bin/env python3
from transformers import BertModel
import torch
import time
# Benchmark forward-pass time of bert-base-uncased on a CUDA device.
# Prints the total wall-clock seconds spent in NUM_ITERS timed forward passes
# (warm-up passes are excluded).

DEVICE = "cuda"
BATCH_SIZE = 16
SEQ_LEN = 256
NUM_ITERS = 5
# Untimed passes that absorb one-off costs (CUDA context / kernel setup,
# library handle creation, allocator growth) which would otherwise inflate
# the first timed iteration.
WARMUP_ITERS = 2

model = BertModel.from_pretrained("bert-base-uncased")
model.to(DEVICE)
# Inference only: make eval mode explicit (disables dropout) and freeze
# parameters so no autograd state is tracked for them.
model.eval()
model.requires_grad_(False)

# Dummy batch where every token id is 1; only the shape matters for timing.
# Allocate directly on the target device instead of building on host and copying.
input_ids = torch.ones((BATCH_SIZE, SEQ_LEN), dtype=torch.long, device=DEVICE)

with torch.no_grad():
    for _ in range(WARMUP_ITERS):
        model(input_ids)

    # CUDA kernels are launched asynchronously. Drain the queue before
    # starting the clock, and again before stopping it, so the measurement
    # covers GPU execution rather than just kernel-launch overhead.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(NUM_ITERS):
        # NOTE(review): BertModel has no task head, so this is the encoder
        # output object rather than classification logits; the name is kept
        # for compatibility.
        logits = model(input_ids)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

print(elapsed)
|