# Running Distributed PyTorch

> **WORK IN PROGRESS**

Training ML models with PyTorch is a resource-heavy operation. Our GPU architecture is not great for this workload: with only one lightweight GPU per node, we frequently do not have enough GPU memory for many training operations.

Distributing your operation across multiple nodes _may_ help alleviate this problem. These examples show one way of accessing GPUs across multiple nodes within our Slurm environment.

> NOTE: this example doesn't do any training; it exists only as a proof of concept (POC) for distributed GPU operations. I hope to improve this example in the future.

## Use

The submit script contains the necessary options, so `sbatch ./hello.sh` will run the example.
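A minimal usage sketch (the `<jobid>` placeholder and the `slurm-<jobid>.out` file name are standard Slurm conventions, not something this repo configures):

```bash
sbatch ./hello.sh        # Slurm replies "Submitted batch job <jobid>"
squeue -u $USER          # watch the job while it is pending/running
cat slurm-<jobid>.out    # Slurm's default stdout/stderr file for the job
```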
## Discussion
## hello.py

```python
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim

import torch.distributed as dist                 # <- this
dist.init_process_group(backend="nccl")          # <- and this
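
# Added note, not in the original file: under torchrun every process can
# report where it sits in the job; init_process_group above blocks until
# all ranks in the rendezvous have joined.
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"initialized rank {rank} of {world_size}")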


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hidden_layer = nn.Linear(1, 1)
        self.hidden_layer.weight = torch.nn.Parameter(torch.tensor([[1.58]]))
        self.hidden_layer.bias = torch.nn.Parameter(torch.tensor([-0.14]))

        self.output_layer = nn.Linear(1, 1)
        self.output_layer.weight = torch.nn.Parameter(torch.tensor([[2.45]]))
        self.output_layer.bias = torch.nn.Parameter(torch.tensor([-0.11]))

    def forward(self, x):
        x = torch.sigmoid(self.hidden_layer(x))
        x = torch.sigmoid(self.output_layer(x))
        return x


net = Net()
print(f"network topology: {net}")

print(f"w_l1 = {round(net.hidden_layer.weight.item(), 4)}")
print(f"b_l1 = {round(net.hidden_layer.bias.item(), 4)}")
print(f"w_l2 = {round(net.output_layer.weight.item(), 4)}")
print(f"b_l2 = {round(net.output_layer.bias.item(), 4)}")

# run input data forward through network
input_data = torch.tensor([0.8])
output = net(input_data)
print(f"a_l2 = {round(output.item(), 4)}")

# backpropagate gradient
target = torch.tensor([1.])
criterion = nn.MSELoss()
loss = criterion(output, target)
net.zero_grad()
loss.backward()

# update weights and biases
optimizer = optim.SGD(net.parameters(), lr=0.1)
optimizer.step()
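
# Worked update for one parameter (added note, not in the original file),
# using the hand-checked values above and the chain rule for L = (a_l2 - 1)^2:
# dL/dw_l2 = 2*(a_l2 - 1) * a_l2*(1 - a_l2) * a_l1 ~= -0.0287,
# so SGD with lr=0.1 nudges w_l2 from 2.45 to roughly 2.4529.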

print(f"updated_w_l1 = {round(net.hidden_layer.weight.item(), 4)}")
print(f"updated_b_l1 = {round(net.hidden_layer.bias.item(), 4)}")
print(f"updated_w_l2 = {round(net.output_layer.weight.item(), 4)}")
print(f"updated_b_l2 = {round(net.output_layer.bias.item(), 4)}")

output = net(input_data)
print(f"updated_a_l2 = {round(output.item(), 4)}")

# tidy up the process group before exiting
dist.destroy_process_group()
```
## hello.sh

```bash
#!/bin/bash
#SBATCH --job-name=multinode-example
#SBATCH --nodes=4
#SBATCH --ntasks=4
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4
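
# Added note (not in the original script): the geometry above is 4 nodes x
# 1 task per node x 1 GPU per task, so torchrun below is launched once per
# node and each launch contributes one worker, for a world size of 4.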

# shamelessly copied from https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/slurm/sbatch_run.sh

set -e
module load PyTorch
set -x

# expand the Slurm nodelist into an array of hostnames; the first entry
# hosts the rendezvous
nodes_array=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO

# --nnodes must match the #SBATCH --nodes count above
srun torchrun --nnodes 4 \
    --nproc_per_node 1 \
    --rdzv_id ${RANDOM} \
    --rdzv_backend c10d \
    --rdzv_endpoint ${head_node_ip}:29500 \
    ./hello.py
```
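For orientation, a hedged sketch (not in this repo) of what each of the four workers sees: `torchrun` turns the rendezvous at `${head_node_ip}:29500` into per-process environment variables, which is how `init_process_group` in `hello.py` finds its peers.

```python
import os

# torchrun exports these to every worker it spawns; with the sbatch options
# above we would expect RANK in 0..3, LOCAL_RANK = 0, and WORLD_SIZE = 4
for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var))
```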