Commit

fp16 for data parallel
root committed Mar 19, 2020
1 parent 0c74fe6 commit 0a258d3
Showing 24 changed files with 10 additions and 7 deletions.
Empty file modified .gitignore (mode 100644 → 100755)
Empty file modified setup.py (mode 100644 → 100755)
Empty file modified tools/bubbles.pdf (mode 100644 → 100755)
Empty file modified tools/generate-pipeline-schedule.cc (mode 100644 → 100755)
Empty file modified tools/generate-schedule-split-backward.cc (mode 100644 → 100755)
Empty file modified varuna/__init__.py (mode 100644 → 100755)
Empty file modified varuna/generate_schedule.cc (mode 100644 → 100755)
Empty file modified varuna/profile.py (mode 100644 → 100755)
varuna/varuna.py (mode 100644 → 100755): 17 changes, 10 additions & 7 deletions
@@ -242,7 +242,6 @@ def __init__(self, batches, model, config, schedule, optimizer):
         self.model = model
         self.partitioned_model = self.model.module if self.data_parallel else self.model
         self.device = config["device"]
-        self.world_size = self.partitions
         self.schedule = schedule
         self.fp16 = config["fp16"]

@@ -321,13 +320,13 @@ def spawn_recieve_workers(self):
             self.acts_recieve_thread.daemon=True
             self.acts_recieve_thread.start()

-        if self.stage < self.world_size-1:
+        if self.stage < self.partitions-1:
             self.grads_recieve_thread = Thread(target=self.grads_reciever, args=())
             self.grads_recieve_thread.daemon=True
             self.grads_recieve_thread.start()

     def spawn_send_workers(self):
-        if self.stage < self.world_size-1:
+        if self.stage < self.partitions-1:
             self.acts_send_thread = Thread(target=self.acts_sender, args=())
             self.acts_send_thread.daemon=True
             self.acts_send_thread.start()
@@ -436,7 +435,6 @@ def recv(grads = False):

     def worker(self, task, grad_mode, inputs_as_dict):
         """ Main body of worker loop """
-        world_size = self.world_size

         if task == 0:
             torch.set_grad_enabled(grad_mode)
@@ -471,11 +469,16 @@ def worker(self, task, grad_mode, inputs_as_dict):
             self.loss = output[0] if isinstance(output,tuple) else output

         else:
-            if self.stage != world_size-1:
+            if self.stage != self.partitions-1:
                 grads = torch.ones(self.loss.size(), dtype = torch.float32).to(self.device)
                 if self.fp16:
-                    with amp.scale_loss(self.loss, self.optimizer, delay_overflow_check=True) as scaled_loss:
+                    # if dist.get_rank() == 1:
+                    #     print("SIZE",self.loss[0][0][0])
+                    with amp.scale_loss(self.loss, self.optimizer) as scaled_loss:
                         scaled_loss.backward(grads)
+                    if dist.get_rank() == 1:
+                        baseModule = self.model.module if not self.data_parallel else self.model.module.module
+                        baseModule.bert.encoder.layer[7].attention.self.query.weight.grad[0][0] = float('inf')
                     # self.optimizer.backward(self.loss, grads=grads)
                     # self.loss.backward(grads)
                 else:
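The hunk above has two moving parts. First, a non-last pipeline stage calls backward() on an activation tensor rather than a scalar loss, so it must pass an explicit upstream gradient (the torch.ones placeholder stands in for the gradient received from the next stage). Second, the new dist.get_rank() == 1 block deliberately writes float('inf') into one BERT weight gradient, which looks like a temporary hook for exercising the fp16 overflow path under data parallelism. A minimal single-process sketch of both ideas, not taken from this repo (the module, shapes, and values are made up):

import torch

# A stand-in for one pipeline stage.
stage = torch.nn.Linear(4, 4)
acts = stage(torch.randn(2, 4, requires_grad=True))

# Non-scalar backward needs an explicit gradient; in the pipeline this would
# be the gradient tensor received from the next stage, not ones.
upstream_grads = torch.ones(acts.size(), dtype=torch.float32)
acts.backward(upstream_grads)

# Simulate the commit's debug hook: poison a single gradient entry with inf,
# then detect it the way an fp16 overflow check would.
stage.weight.grad[0][0] = float('inf')
overflow = any(
    p.grad is not None and not torch.isfinite(p.grad).all()
    for p in stage.parameters()
)
print("overflow detected:", overflow)  # True -> the optimizer step should be skipped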
@@ -487,7 +490,7 @@ def worker(self, task, grad_mode, inputs_as_dict):
                 self.average_loss += self.loss.item()

                 if self.fp16:
-                    with amp.scale_loss(self.loss, self.optimizer, delay_overflow_check=False) as scaled_loss:
+                    with amp.scale_loss(self.loss, self.optimizer, delay_overflow_check=False, last_partition=False) as scaled_loss:
                         scaled_loss.backward()
                     # self.optimizer.backward(self.loss)
                 else:
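On the last pipeline partition the real loss is available, so backward() is called without an explicit gradient. Note that delay_overflow_check is a genuine apex.amp.scale_loss keyword, but last_partition is not part of stock Apex; it presumably comes from the patched Apex that Varuna depends on. The underlying requirement, matching the commit title, is that fp16 under data parallelism needs every replica to agree on overflow so they all skip (or all take) the optimizer step together. A generic sketch of that agreement, with hypothetical names and no relation to the patched amp internals:

import torch
import torch.distributed as dist

def step_or_skip(optimizer, parameters):
    # Assumes dist.init_process_group(...) has already been called.
    local_overflow = any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for p in parameters
    )
    flag = torch.tensor([1.0 if local_overflow else 0.0])
    # MAX across ranks: if any replica saw inf/nan, every replica skips.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    if flag.item() > 0:
        optimizer.zero_grad()  # skip this step; amp would also lower the loss scale
    else:
        optimizer.step()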
Empty file modified varuna_for_bert/README.md (mode 100644 → 100755)
Empty file modified varuna_for_bert/bert_original_worker.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/check_running.sh (mode 100644 → 100755)
Empty file modified varuna_for_bert/launch_bert.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/launch_bert_with_morphing.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/modeling_bert.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/profile-2.csv (mode 100644 → 100755)
Empty file modified varuna_for_bert/profile_bert.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/setup.sh (mode 100644 → 100755)
Empty file modified varuna_for_bert/signal_remote.sh (mode 100644 → 100755)
Empty file modified varuna_for_bert/start_remote.sh (mode 100644 → 100755)
Empty file modified varuna_for_bert/utils_squad.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/utils_squad_evaluate.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/varuna_worker.py (mode 100644 → 100755)
Empty file modified varuna_for_bert/varuna_worker_with_morphing.py (mode 100644 → 100755)
