#!/usr/bin/env python3
"""Time repeated forward passes of ``bert-base-uncased`` on a CUDA device.

Loads the pretrained model, moves it to the GPU, builds a fixed batch of
token ids, runs five gradient-free forward passes and prints the elapsed
time in seconds.
"""
import time

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.to("cuda")
# Inference only: no parameter needs gradients.
model.requires_grad_(False)

# Dummy batch of shape (16, 256) with every token id set to 1, allocated
# directly on the GPU instead of being created on the CPU and copied over.
input_ids = torch.ones((16, 256), dtype=torch.long, device="cuda")

# CUDA kernels are launched asynchronously, so the host clock must only be
# read once the device is idle; otherwise the measurement can exclude work
# that is still queued on the GPU.
torch.cuda.synchronize()
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
with torch.no_grad():
    # NOTE(review): there is no warm-up pass, so the first iteration
    # presumably includes one-time CUDA/library initialisation -- confirm
    # whether that cost should be part of the reported number.
    for _ in range(5):
        outputs = model(input_ids)
# Wait for all five forward passes to actually finish before stopping the clock.
torch.cuda.synchronize()
print(time.perf_counter() - start_time)