tools/run_no_grad.py
#!/usr/bin/env python3
import time

import torch
from transformers import BertModel

# Load a pretrained BERT encoder and move it to the GPU.
model = BertModel.from_pretrained("bert-base-uncased")
model.to("cuda")

# Dummy batch: 16 sequences of length 256 (all token id 1).
input_ids = torch.ones((16, 256), dtype=torch.long)
input_ids = input_ids.to("cuda")

# Disable gradient tracking for all model parameters.
model.requires_grad_(False)

# Time five forward passes under torch.no_grad().
start_time = time.time()
for _ in range(5):
    with torch.no_grad():
        logits = model(input_ids)
print(time.time() - start_time)
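
# Note: CUDA kernels are launched asynchronously, so the wall-clock time above can
# be read before the queued GPU work has finished. A minimal sketch of a variant
# that synchronizes the device around the timed region (reusing the model and
# input_ids defined above) is included below as an illustration; it is not part
# of the original script.
torch.cuda.synchronize()
start_time = time.time()
for _ in range(5):
    with torch.no_grad():
        outputs = model(input_ids)
torch.cuda.synchronize()
print(time.time() - start_time)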