[BART/PyT] Add synchronize for benchmarking
meatybobby committed Feb 13, 2023
1 parent afea561 commit a797214
Showing 2 changed files with 4 additions and 0 deletions.
PyTorch/LanguageModeling/BART/run_eval.py (2 additions, 0 deletions)

@@ -160,6 +160,7 @@ def generate_summaries_or_translations(
     results = []
     with torch.no_grad():
         for batch in tqdm(data_loader):
+            torch.cuda.synchronize()
             t0 = time.time()
 
             summaries = model.generate(
@@ -180,6 +181,7 @@ def generate_summaries_or_translations(
             if num_return_sequences > 1:
                 preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
 
+            torch.cuda.synchronize()
             eval_time = time.time() - t0
             for i, pred in enumerate(preds):
                 store_time = eval_time if i == 0 else None #only store latency for element 0 of every batch
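The pattern added above is the standard host-side recipe for timing GPU work: CUDA kernels launch asynchronously, so without a synchronize() on each side of the timed region, time.time() captures little more than kernel-launch overhead. A minimal, self-contained sketch of the same pattern (the timed_generate helper and its arguments are hypothetical, not part of this repository):

import time

import torch

def timed_generate(model, batch):
    # Assumes model and the batch tensors already live on a CUDA device.
    torch.cuda.synchronize()   # drain queued kernels before starting the clock
    t0 = time.time()
    with torch.no_grad():
        out = model.generate(**batch)
    torch.cuda.synchronize()   # wait for the GPU to finish before stopping the clock
    return out, time.time() - t0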
PyTorch/LanguageModeling/BART/training_base.py (2 additions, 0 deletions)

@@ -410,9 +410,11 @@ def generic_train(
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
             local_step += 1
+            torch.cuda.synchronize()
             iter_start = time.time()
 
             total_loss, logs = train_one_step(args, trainer, optimizer, scheduler, batch, local_step, scaler)
+            torch.cuda.synchronize()
             train_perf = logs["bs"] * get_world_size() / (time.time() - iter_start)
 
 
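The same bracketing in the training loop above ensures train_perf reflects completed GPU work rather than launch time. An alternative worth noting, though not what this commit uses, is timing with CUDA events, which are recorded on the stream itself and need only a single synchronize at the end:

import torch

# Sketch of event-based GPU timing; requires a CUDA device.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()  # enqueued on the current CUDA stream
y = torch.randn(2048, 2048, device="cuda") @ torch.randn(2048, 2048, device="cuda")
end.record()

torch.cuda.synchronize()  # make sure both events have completed
print(f"elapsed: {start.elapsed_time(end):.2f} ms")  # elapsed_time() returns milliseconds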
