When you have warnings and asserts (or debug prints), it helps a lot to prefix each log with its hostname:rank
python -m torch.distributed.run --role $(hostname -s): --tee 3 --nnodes 1 --nproc_per_node 2 torch-distributed-gpu-test.py
Now each log line will be prefixed with [hostname:rank]
Note that the colon at the end of the --role entry is important, that's how you get the hostname:rank prefix. But you can use any other separator there, e.g. if you use -, you will end up with a hostname-rank prefix.
If you're in a SLURM environment the above command line becomes:
srun --jobid $SLURM_JOBID bash -c 'python -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \
--master_addr $MASTER_ADDR --master_port $MASTER_PORT \
--role $(hostname -s): --tee 3 \
torch-distributed-gpu-test.py'
Of course adjust your environment variables to match, this was just an example.
Important! Note that I'm using a single-quoted string of commands passed to bash -c. This way the hostname -s command is delayed until it's run on each of the nodes. If you used double quotes above, hostname -s would get executed on the starting node and then all nodes would get the same hostname as the prefix, which defeats the purpose of using these flags. So if you use double quotes you need to rewrite the above like so:
srun --jobid $SLURM_JOBID bash -c "python -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank \$SLURM_PROCID \
--master_addr $MASTER_ADDR --master_port $MASTER_PORT \
--role \$(hostname -s): --tee 3 \
torch-distributed-gpu-test.py"
$SLURM_PROCID is escaped too as it needs to be specific to each node and it's unknown during the launch of the SLURM job on the main node. So there are 2 \$ escapes in this version of the code.
This prefixing functionality is also super-helpful when a distributed program fails, which often results in interleaved assert messages that are very difficult to interpret. By grepping for one node:rank string of choice, it's now possible to reconstruct the real error message.
For example, if you get a traceback that looks like:
File "/path/to/training/dataset.py", line 785, in __init__
File "/path/to/training/dataset.py", line 785, in __init__
if self.dataset_proba.sum() != 1:
AttributeError: 'list' object has no attribute 'sum'
if self.dataset_proba.sum() != 1:
File "/path/to/training/dataset.py", line 785, in __init__
File "/path/to/training/dataset.py", line 785, in __init__
if self.dataset_proba.sum() != 1:
if self.dataset_proba.sum() != 1:
AttributeError: 'list' object has no attribute 'sum'
AttributeError: 'list' object has no attribute 'sum'
AttributeError: 'list' object has no attribute 'sum'
and when it's dozens of frames over 8 nodes it can't be made sense of, but the above --tee + --role will generate:
[host1:0] File "/path/to/training/dataset.py", line 785, in __init__
[host1:1] File "/path/to/training/dataset.py", line 785, in __init__
[host1:0] if self.dataset_proba.sum() != 1:
[host1:0]AttributeError: 'list' object has no attribute 'sum'
[host1:1] if self.dataset_proba.sum() != 1:
[host1:2] File "/path/to/training/dataset.py", line 785, in __init__
[host1:3] File "/path/to/training/dataset.py", line 785, in __init__
[host1:3] if self.dataset_proba.sum() != 1:
[host1:2] if self.dataset_proba.sum() != 1:
[host1:1]AttributeError: 'list' object has no attribute 'sum'
[host1:2]AttributeError: 'list' object has no attribute 'sum'
[host1:3]AttributeError: 'list' object has no attribute 'sum'
and you can grep this output for just one host:rank prefix, which gives us:
$ grep -F "[host1:0]" log.txt
[host1:0] File "/path/to/training/dataset.py", line 785, in __init__
[host1:0] if self.dataset_proba.sum() != 1:
[host1:0]AttributeError: 'list' object has no attribute 'sum'
and voila, you can now tell what really happened. And as I mentioned earlier there can easily be a few hundred interleaved assert lines there - I was demoing a small example.
Also, if you have just one node, you can just pass --tee 3 and there is no need to pass --role.
And of course if you're doing debug prints, then to solve this exact issue you can use printflock.
When using CUDA, failing pytorch programs very often produce a python traceback that makes no sense or can't be acted upon. This is due to CUDA's async nature - by the time a CUDA kernel actually executes, the program has already moved on, so when the error happens the context of the program isn't there. The async functionality is there to make things faster, so that while the GPU is churning through some matmul the program on the CPU can already start doing something else.
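To make this concrete, here is a minimal sketch (assuming a CUDA-capable machine) where the faulty kernel is launched on one line, but the error typically only surfaces later, at a synchronization point, so the traceback points at an unrelated line:

import torch

x = torch.ones(10, device="cuda")
idx = torch.tensor([100], device="cuda")  # out-of-bounds index
y = x[idx]             # the faulty gather kernel is launched asynchronously - usually no error yet
z = x * 2              # the CPU side happily queues more work
print(y.cpu())         # the device-side assert typically only gets reported here, at the sync point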
At other times some parts of the system will actually tell you that they couldn't generate the correct traceback, as in this error:
[E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the
asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/
incomplete data. To avoid this inconsistency, we are taking the entire process down.
There are a few solutions.
If the failure is instant and can be reproduced on CPU (not all programs work on CPU), simply re-run it after hiding your GPUs. This is how you do it:
CUDA_VISIBLE_DEVICES="" python my-pytorch-program.py
The env var CUDA_VISIBLE_DEVICES is used to manually limit the visibility of GPUs to the executed program. So for example if you have 8 gpus and you want to run program1.py with the first 4 gpus and program2.py with the remaining 4 gpus you can do:
CUDA_VISIBLE_DEVICES="0,1,2,3" python my-pytorch-program1.py
CUDA_VISIBLE_DEVICES="4,5,6,7" python my-pytorch-program2.py
and the second program won't be the wiser that it's not using GPUs 0-3.
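A quick way to see the effect from inside python: the visible GPUs get renumbered starting from 0, so the second program sees its 4 GPUs as cuda:0 through cuda:3. A small sketch, assuming an 8-GPU box (check-gpus.py is just a placeholder name):

# run as: CUDA_VISIBLE_DEVICES="4,5,6,7" python check-gpus.py
import torch
print(torch.cuda.device_count())      # 4, not 8
print(torch.cuda.get_device_name(0))  # this is physical GPU 4, exposed as cuda:0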
But in the case of debugging we are hiding all GPUs, by setting CUDA_VISIBLE_DEVICES="".
Now the program runs on CPU and you will get a really nice traceback and will fix the problem in no time.
But, of course, if your program requires multiple GPUs this won't work. And so here is another solution.
Rerun your program after setting this environment variable:
CUDA_LAUNCH_BLOCKING=1 python my-pytorch-program.py
This variable tells pytorch (or any other CUDA-based program) to turn its async nature off everywhere and now all operations will be synchronous. So when the program crashes you should now get a perfect traceback and you will know exactly what ails your program.
In theory enabling this variable should make everything run really slow, but in reality it really depends on your software. We did the whole of BLOOM-176B training using CUDA_LAUNCH_BLOCKING=1 with [Megatron-DeepSpeed](https://github.com/bigscience-workshop/Megatron-DeepSpeed) and had zero slowdown - we had to use it as pytorch was hanging without it and we had no time to figure the hanging out.
So, yes, switching from async to sync mode can hide some subtle race conditions, so there are times when a hang disappears as in the example I shared above. So measure your throughput with and without this flag - sometimes it might not only help with getting an in-context traceback but actually solve your problem altogether.
Note: NCCL==2.14.3 coming with pytorch==1.13 hangs when CUDA_LAUNCH_BLOCKING=1 is used. So don't use it with that version of pytorch. The issue has been fixed in nccl>=2.17 which should be included in pytorch==2.0.
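If you hit that hang, or simply don't want to slow the whole program down, a more targeted (if tedious) alternative is to force a synchronization right after each suspect operation, so that any pending async error is raised while the relevant python frame is still on the stack. A minimal sketch - suspect_op is a hypothetical stand-in for whatever call you suspect:

import torch

out = suspect_op(inputs)    # hypothetical stand-in for the call you suspect launches the bad kernel
torch.cuda.synchronize()    # block until all queued kernels complete - a pending async error surfaces here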
It's not uncommon for a complex pytorch program to segfault and drop a core file. Especially if you're using complex extensions like NCCL.
The core file is what the program generates when it crashes at a low level - e.g. when using a python extension such as a CUDA kernel, or really any library coded directly in some variant of C or another language and made accessible in python through some binding API. The most common cause of a segfault is such software accessing memory it has not allocated. For example, a program may try to free memory it hasn't allocated. But there could be many other reasons.
When a segfault happens Python can't do anything, as the proverbial carpet is pulled out from under its feet, so it can't generate an exception or even write anything to the output.
In these situations one must go and analyse the libC-level calls that led to the segfault, which are luckily saved in the core file.
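One cheap thing worth doing before reaching for the core file: python's stdlib faulthandler module installs a low-level signal handler that can often still dump the python-level traceback of every thread when a SIGSEGV arrives. It won't replace the gdb analysis below, but it costs one line - a minimal sketch:

import faulthandler, sys

# dump the python traceback of all threads to stderr on SIGSEGV/SIGFPE/SIGABRT/SIGBUS
faulthandler.enable(file=sys.stderr, all_threads=True)

The same can be enabled without touching the code via the PYTHONFAULTHANDLER=1 env var or python -X faulthandler.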
If your program crashed, you will often find a file that will look something like: core-python-3097667-6
Before we continue make sure you have gdb installed:
sudo apt-get install gdb
Now make sure you know the path to the python executable that was used to run the program that crashed. If you have multiple python environments you have to activate the right environment first. If you don't, gdb may fail to unpack the core file.
So typically I'd go:
conda activate my-env
gdb python core-python-3097667-6
- adjust my-env to whatever env you use, or instead of conda use whatever way you use to activate your python environment - and perhaps you're using the system-wide python and then you don't need to activate anything.
- adjust the name of the core file to the file you have gotten - it's possible that there are many - pick the latest one.
Now gdb will churn for a bit and will give you a prompt where you type: bt. We will use an actual core file here:
(gdb) bt
#0 0x0000147539887a9f in raise () from /lib64/libc.so.6
#1 0x000014753985ae05 in abort () from /lib64/libc.so.6
#2 0x000014751b85a09b in __gnu_cxx::__verbose_terminate_handler() [clone .cold.1] () from /lib64/libstdc++.so.6
#3 0x000014751b86053c in __cxxabiv1::__terminate(void (*)()) () from /lib64/libstdc++.so.6
#4 0x000014751b860597 in std::terminate() () from /lib64/libstdc++.so.6
#5 0x000014751b86052e in std::rethrow_exception(std::__exception_ptr::exception_ptr) () from /lib64/libstdc++.so.6
#6 0x000014750bb007ef in c10d::ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() ()
from .../python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so
#7 0x000014750bb04c69 in c10d::ProcessGroupNCCL::workCleanupLoop() ()
from .../python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so
#8 0x000014751b88cba3 in execute_native_thread_routine () from /lib64/libstdc++.so.6
#9 0x000014753a3901cf in start_thread () from /lib64/libpthread.so.0
#10 0x0000147539872dd3 in clone () from /lib64/libc.so.6
and there you go. How do you make sense of it?
Well, you go from the bottom of the stack to the top. You can tell that a clone call was made in libc which then called start_thread in libpthread, and if you keep going there are a bunch of calls in the torch libraries and finally we can see that the program terminated itself, completing with raise from libc which told the Linux kernel to kill the program and create the core file.
This wasn't an easy backtrace to understand.
footnote: Yes, python calls it a traceback and elsewhere it's called a backtrace - it's confusing, but it's more or less the same thing.
Actually I had to ask pytorch devs for help and received:
- PyTorch ProcessGroup watchdog thread caught an asynchronous error from NCCL
- This error is an “unhandled system error” which in this particular case turned out to be an IB-OPA error
- The ProcessGroup's WorkCleanUp thread rethrew the error so that the main process would crash and the user would get notified (otherwise this async error would not surface)
Trust me, there are times when even if you're inexperienced the backtrace can give you enough of a hint as to where you should look for troubleshooting.
But fear not - most of the time you won't need to understand the traceback. Ideally you'd just attach the core file to your filed Issue. But it can easily be 5GB large. So the developers who will be trying to help you will ask you to generate a gdb backtrace, and now you know how to do that.
I didn't promise it'll be easy, I just showed you where to start.
Now another useful detail is that many programs these days run multiple threads. And bt only shows the main thread of the process. But, often, it can be helpful to see where the other threads in the process were when the segfault happened. For that you simply type 2 commands at the (gdb) prompt:
(gdb) thread apply all bt
(gdb) bt
and this time around you typically will get a massive report, one backtrace per thread.
Similar to py-spy, strace is a super-useful tool which traces any running application at the low level of system calls - e.g. libC and the like.
For example, run:
strace python -c "print('strace')"
and you will see everything that is done at the system call level as the above program runs.
But usually it's more useful when you have a stuck program that spins all CPU cores at 100% but nothing happens and you want to see what it's doing. In this situation you simply attach to the running program like so:
strace --pid PID
where you get the PID, for example, from the output of top or ps. Typically I just copy-n-paste the PID of the program that consumes the most CPU - top usually shows it at the very top of its listing.
Same as with py-spy, you may need sudo permissions to attach to an already running process - it all depends on your system setup. But you can always start a program with strace as I have shown in the original example.
Let's look at a small sub-snippet of the output of strace python -c "print('strace')"
write(1, "strace\n", 7strace
) = 7
Here we can see that a write call was executed on file descriptor 1, which almost always is stdout (stdin being 0, and stderr being 2).
If you're not sure what a file descriptor is pointing to, normally you can tell from strace's output itself. But you can also do:
ls -l /proc/PID/fd
where PID is the pid of the currently running program you're trying to investigate.
For example, when I run the above while running a pytest test with gpus, I got (partial output):
l-wx------ 1 stas stas 64 Mar 1 17:22 5 -> /dev/null
lr-x------ 1 stas stas 64 Mar 1 17:22 6 -> /dev/urandom
lrwx------ 1 stas stas 64 Mar 1 17:22 7 -> /dev/nvidiactl
lrwx------ 1 stas stas 64 Mar 1 17:22 8 -> /dev/nvidia0
lr-x------ 1 stas stas 64 Mar 1 17:22 9 -> /dev/nvidia-caps/nvidia-cap2
so you can see that the device /dev/null is open as FD (file descriptor) 5, /dev/urandom as FD 6, etc.
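The same mapping can be pulled programmatically, which is handy if you want to log it from inside the process you're debugging - a small sketch, Linux-only since it relies on /proc:

import os

pid = os.getpid()  # or the PID of another process you're investigating, if permissions allow
fd_dir = f"/proc/{pid}/fd"
for fd in sorted(os.listdir(fd_dir), key=int):
    try:
        print(fd, "->", os.readlink(f"{fd_dir}/{fd}"))
    except OSError:
        pass  # the fd may have been closed between listdir and readlink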
Now let's go look at another snippet from our strace run.
access("/etc/ld.so.preload", R_OK) = -1 ENOENT (No such file or directory)
Here it tried to see if the file /etc/ld.so.preload exists, but as we can see it doesn't - this can be useful if some shared library is missing - you can see where it's trying to load it from.
Let's try another one:
openat(AT_FDCWD, "/lib/x86_64-linux-gnu/libpthread.so.0", O_RDONLY|O_CLOEXEC) = 3
read(3, "\177ELF\2\1\1\0\0\0\0\0\0\0\0\0\3\0>\0\1\0\0\0\0\0\0\0\0\0\0\0"..., 832) = 832
newfstatat(3, "", {st_mode=S_IFREG|0644, st_size=21448, ...}, AT_EMPTY_PATH) = 0
mmap(NULL, 16424, PROT_READ, MAP_PRIVATE|MAP_DENYWRITE, 3, 0) = 0x7f8028807000
mmap(0x7f8028808000, 4096, PROT_READ|PROT_EXEC, MAP_PRIVATE|MAP_FIXED|MAP_DENYWRITE, 3, 0x1000) = 0x7f8028808000
mmap(0x7f8028809000, 4096, PROT_READ, MAP_PRIVATE|MAP_FIXED|MAP_DENYWRITE, 3, 0x2000) = 0x7f8028809000
mmap(0x7f802880a000, 8192, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_FIXED|MAP_DENYWRITE, 3, 0x2000) = 0x7f802880a000
close(3)
here we can see that it opens /lib/x86_64-linux-gnu/libpthread.so.0 and assigns it FD 3, it then reads 832 bytes from FD 3 (we can also see that the first characters are ELF - the magic marker of the Executable and Linkable Format used by shared libraries), then memory maps it and closes that file.
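If you want to verify that hunch yourself, the ELF magic bytes are easy to check - a tiny sketch using the same library path:

with open("/lib/x86_64-linux-gnu/libpthread.so.0", "rb") as f:
    print(f.read(4) == b"\x7fELF")  # True for any ELF executable or shared library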
In this following example, we see a python cached file is opened, its filepointer is moved to 0, and then it's read and closed.
openat(AT_FDCWD, "/home/stas/anaconda3/envs/py38-pt113/lib/python3.8/__pycache__/abc.cpython-38.pyc", O_RDONLY|O_CLOEXEC) = 3
fstat(3, {st_mode=S_IFREG|0664, st_size=5329, ...}) = 0
lseek(3, 0, SEEK_CUR) = 0
lseek(3, 0, SEEK_CUR) = 0
fstat(3, {st_mode=S_IFREG|0664, st_size=5329, ...}) = 0
brk(0x23bf000) = 0x23bf000
read(3, "U\r\r\n\0\0\0\0\24\216\177c\211\21\0\0\343\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"..., 5330) = 5329
read(3, "", 1) = 0
close(3)
It's important to notice that file descriptors are re-used, so we have seen the same FD 3 twice, but each time it was opened for a different file.
If your program is, for example, trying to reach out to the Internet, you can also tell these calls from strace, as the program would be reading from a socket file descriptor.
So let's run an example on a program that downloads files from the HF hub:
strace python -c 'import sys; from transformers import AutoConfig; AutoConfig.from_pretrained(sys.argv[1])' t5-small
here is a snippet relevant to this discussion:
socket(AF_INET6, SOCK_STREAM|SOCK_CLOEXEC, IPPROTO_TCP) = 3
setsockopt(3, SOL_TCP, TCP_NODELAY, [1], 4) = 0
ioctl(3, FIONBIO, [1]) = 0
connect(3, {sa_family=AF_INET6, sin6_port=htons(443), sin6_flowinfo=htonl(0), inet_pton(AF_INET6, "2600:1f18:147f:e850:e203:c458:10cd:fc3c", &sin6_addr), sin6_scope_id=0}, 28) = -1 EINPROGRESS (Operation now in progress)
poll([{fd=3, events=POLLOUT|POLLERR}], 1, 10000) = 1 ([{fd=3, revents=POLLOUT}])
getsockopt(3, SOL_SOCKET, SO_ERROR, [0], [4]) = 0
[...]
write(3, "\26\3\3\0F\20\0\0BA\4\373m\244\16\354/\334\205\361j\225\356\202m*\305\332\275\251\17J"..., 126) = 126
read(3, 0x2f05c13, 5) = -1 EAGAIN (Resource temporarily unavailable)
poll([{fd=3, events=POLLIN}], 1, 9903) = 1 ([{fd=3, revents=POLLIN}])
read(3, "\24\3\3\0\1", 5) = 5
read(3, "\1", 1) = 1
read(3, "\26\3\3\0(", 5) = 5
read(3, "\0\0\0\0\0\0\0\0\344\v\273\225`\4\24m\234~\371\332%l\364\254\34\3472<\0356s\313"..., 40) = 40
ioctl(3, FIONBIO, [1]) = 0
poll([{fd=3, events=POLLOUT}], 1, 10000) = 1 ([{fd=3, revents=POLLOUT}])
write(3, "\27\3\3\1.\0\374$\361\217\337\377\264g\215\364\345\256\260\211$\326pkR\345\276,\321\221`-"..., 307) = 307
ioctl(3, FIONBIO, [1]) = 0
read(3, 0x2ef7283, 5) = -1 EAGAIN (Resource temporarily unavailable)
poll([{fd=3, events=POLLIN}], 1, 10000) = 1 ([{fd=3, revents=POLLIN}])
You can see that again it uses FD 3, but this time it opens an INET6 socket instead of a file. You can see that it then connects to that socket, polls, and reads and writes from it.
There are many other super-useful insights one can derive from using this tool.
BTW, if you don't want to scroll up-down, you can also save the output to a file:
strace -o strace.txt python -c "print('strace')"
While the methodologies found in this article were developed while working with multi-node multi-gpu pytorch-based training, they, of course, can help with any multi-process multi-node Python programs.
Try to use the following script torch-distributed-gpu-test.py to diagnose the situation.
This will help primarily with discovering network-related issues. And also to quickly understand how multi-gpu communications work.
For code-related issues read the rest of this document.
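If you don't have that script handy, here is a minimal sketch of the kind of sanity check it performs - each rank does a collective and a barrier, so any NCCL or networking problem shows up (or hangs) right away. The real torch-distributed-gpu-test.py is more thorough, and the file name below is just a placeholder:

# launch with e.g.: python -m torch.distributed.run --nproc_per_node 2 --nnodes 1 sanity-check.py
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

t = torch.ones(1, device=f"cuda:{local_rank}")
dist.all_reduce(t)   # if the NCCL/network setup is broken, this is where it will hang or fail
dist.barrier()
print(f"rank {dist.get_rank()}: all_reduce OK, sum={t.item()}")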
First do pip install py-spy.
Now you can attach to each process with:
py-spy dump -n -p PID
and it will tell you where the process hangs (very often it's a nccl collective function or a barrier).
- PID is the process id of the hanging python process.
- -n is useful if you want to see stack traces from python extensions written in C, C++, etc., as the program may hang in one of the extensions
- you may need to add sudo before the command - for more details see this note.
Here is an example of such a stack trace:
Thread 835995 (active): "MainThread"
broadcast (torch/distributed/distributed_c10d.py:1191)
_aggregate_total_loss (deepspeed/runtime/pipe/engine.py:540)
train_batch (deepspeed/runtime/pipe/engine.py:330)
train_step (megatron/training.py:436)
train (megatron/training.py:851)
pretrain (megatron/training.py:187)
<module> (pretrain_gpt.py:239)
The very first line is where the program is stuck.
Now, how do you do it for multiple processes? Doing it one-by-one is too slow, so let's do it all at once.
If the launch command was python, what you do is:
pgrep -P $(pgrep -o python) | xargs -I {} py-spy dump --pid {}
if deepspeed:
pgrep -P $(pgrep -o deepspeed) | xargs -I {} py-spy dump --pid {}
for accelerate:
pgrep -P $(pgrep -o accelerate) | xargs -I {} py-spy dump --pid {}
you get the idea.
This particular approach will only analyse the main processes and not various other sub-processes/threads spawned by these processes. So if you have 8 gpus and 8 processes, the above will generate 8 stack traces.
If you want all processes and their subprocesses, then you'd just run:
pgrep -f python | xargs -I {} py-spy dump --pid {}
(and as before replace python with the name of the launcher program if it's not python)
What if you have multiple nodes?
You can of course ssh to each node interactively and dump the stack traces.
If you're using the SLURM environment you can use srun to do it on all nodes for you.
Now in another console get the SLURM_JOBID (or get it from the salloc log):
squeue -u `whoami` -o "%.16i %.9P %.26j %.8T %.10M %.8l %.6D %.20S %R"
Now use the following srun command, after adjusting the jobid with the SLURM_JOBID obtained from the command above:
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 --output=trace-%N.out sh -c 'ps aux | grep python | egrep -v "grep|srun" | grep `whoami` | awk "{print \$2}" | xargs -I {} py-spy dump --native --pid {}' || echo "failed"
Notes:
- One must use --gres=gpu:0 for the monitor srun or otherwise it will block until the main srun (the one running the training) exits.
- Each node will generate its unique log file named trace-nodename.out - so this would help to identify which node(s) are problematic. You can remove --output=trace-%N.out if you want it all being dumped to stdout.
- In some SLURM versions you may also need to add --overlap.
- In some SLURM versions the jobid might not match that reported in squeue, so you have to get the correct SLURM_JOB_ID from the logs of the job you're trying to "attach" to - i.e. your srun job that allocated the GPUs.
- Sometimes bash doesn't work, but sh does. I think it has to do with which dot files get sourced.
- You might need to also activate a custom python environment, which you can do like so:
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 --output=trace-%N.out sh -c 'conda activate myenvname; ps auxc | ... ' || echo "failed"
or you can do it inside ~/.bashrc or whatever shell's rc file you decide to use.
As mentioned before if you want just the main processes you'd use this instead:
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 --output=trace-%N.out sh -c 'pgrep -P $(pgrep -o python) | xargs -I {} py-spy dump --pid {}' || echo "failed"
Adjust python if need be as explained in the multi-gpu section above.
The previous longer command will deliver traces for all python processes.
If you're not getting anything, start with the basic debug like:
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 --output=trace-%N.out sh -c 'date'
once you know you're talking to all the nodes, then you can progressively unravel the depth of calls, as in:
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 sh -c 'date'
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 sh -c 'pgrep -o python'
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 sh -c 'pgrep -P $(pgrep -o python) '
srun --jobid=2180718 --gres=gpu:0 --nodes=40 --tasks-per-node=1 sh -c 'pgrep -P $(pgrep -o python) | xargs -I {} py-spy dump --pid {}'
and at each stage check that the output makes sense - e.g. for the 2nd and 3rd calls you should be getting the PIDs of the processes.
The following notes require pip install deepspeed.
In one SLURM environment I also attempted using pdsh via ds_ssh, but somehow I wasn't able to run py-spy remotely - the main issue was that the remote ssh command wasn't giving the same env as when I was logged in interactively via ssh. But if you have sudo access on the compute nodes then you could do:
First prepare the hostfile:
function makehostfile() {
perl -e '$slots=split /,/, $ENV{"SLURM_STEP_GPUS"};
$slots=8 if $slots==0; # workaround 8 gpu machines
@nodes = split /\n/, qx[scontrol show hostnames $ENV{"SLURM_JOB_NODELIST"}];
print map { "$_ slots=$slots\n" } @nodes'
}
makehostfile > hostfile
Adapt $slots to the number of gpus per node. You may have to adapt this script if your scontrol produces a different output.
Now run the py-spy extraction command over all participating nodes:
ds_ssh -f hostfile "source ~/.pdshrc; ps aux | grep python | grep -v grep | grep `whoami` | awk '{print \$2}' | xargs -I {} sudo py-spy dump --pid {} "
The hanging could be happening at the network level. NCCL_DEBUG=INFO can help here.
Run the script with the NCCL_DEBUG=INFO env var and try to study the outcome for obvious errors. It will tell you which device it's using, e.g.:
DeepWhite:21288:21288 [0] NCCL INFO NET/Socket : Using [0]enp67s0:192.168.50.21<0>
So it's using interface enp67s0 over 192.168.50.21.
Is your 192.168.50.21 firewalled? Or is it somehow a misconfigured network device?
Does it work if you use the loopback device 127.0.0.1?
NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=lo python -m torch.distributed.run --nproc_per_node 4 --nnodes 1 torch-distributed-gpu-test.py
if not, see what other local network devices you have via ifconfig - try that instead of lo if any. It's currently using enp67s0 in the above example.
You can also try to see if only some GPUs fail. For example, does it work if you use the first 2 or the last 2 gpus:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.run --nproc_per_node 2 --nnodes 1 torch-distributed-gpu-test.py
then the 2nd pair:
CUDA_VISIBLE_DEVICES=2,3 python -m torch.distributed.run --nproc_per_node 2 --nnodes 1 torch-distributed-gpu-test.py
Now what happens when the training doesn't just hang, but the hanging process stops responding? e.g. this happens when there is a serious hardware issue. But what if it is recurrent? py-spy won't help here, since it won't be able to attach to a process that is not responding.
So next came the idea of tracing all calls like one does with strace(1). I researched python call-tracing facilities and discovered that python has a trace sub-system.
The following code will trace all python calls and log them to the console and into a dedicated per-process log file, via a custom Tee module I added.
This then can help to understand where some processes stopped responding, since we will have the log of the last call and all the previous calls before it went unresponsive.
$ cat train.py
[...]

def main():
    # [...]
    train()

import re

class Tee:
    """
    A helper class to tee print's output into a file.
    Usage:
        sys.stdout = Tee(filename)
    """
    def __init__(self, filename):
        self.stdout = sys.stdout
        self.file = open(filename, "a")

    def __getattr__(self, attr):
        return getattr(self.stdout, attr)

    def write(self, msg):
        self.stdout.write(msg)
        self.file.write(msg)
        self.file.flush()

    def flush(self):
        self.stdout.flush()
        self.file.flush()

if __name__ == "__main__":
    import sys
    import trace
    import socket
    import os

    # enable the trace
    if 0:
        cwd = os.path.realpath('.')
        pid = os.getpid()
        hostname = socket.gethostname()
        local_rank = int(os.environ["LOCAL_RANK"])
        trace_output_file = f"{cwd}/trace-{hostname}-{local_rank}-{pid}.txt"

        # create a Trace object, telling it what to ignore, and whether to
        # do tracing or line-counting or both.
        tracer = trace.Trace(
            ignoredirs=[sys.prefix, sys.exec_prefix],
            trace=1,
            count=1,
            timing=True,
        )

        # run the new command using the given tracer
        sys.stdout = Tee(trace_output_file)
        tracer.run('main()')
    else:
        main()
This code doesn't require any special handling other than enabling the trace by changing if 0 to if 1.
If you don't set ignoredirs, this will now dump all python calls, which means you should expect many GBs of data logged, especially if you have hundreds of GPUs.
Of course, you don't have to start tracing from main - if you suspect a specific area you can start tracing there instead and it'll be much faster with less data to save.
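For example, instead of tracer.run('main()') you can hand the tracer just the suspect function via its runfunc method - a small sketch, where train_step and batch are placeholders for whatever you suspect:

# hypothetical: trace only the suspect call rather than the whole program
tracer.runfunc(train_step, batch)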
I wish I could tell trace which packages to follow, but alas it only supports dirs to ignore, which is much more difficult to set, and thus you end up with a lot more data than needed. But still this is a super useful tool for debugging hanging processes.
Also, your code will now run much much slower and the more packages you trace the slower it will become.
As Trace proved to provide very limited usability when debugging a complex multi-node multi-hour run crash, I have started working on a better version of the trace module. You can find it here: NicerTrace
I added multiple additional flags to the constructor and made the output much more useful. You will find a full working example in that same file, just run:
python NicerTrace.py
and you should see:
trace/NicerTrace.py:1 <module>
0:00:00 <string>: 1: trace/NicerTrace.py:185 main
0:00:00 NicerTrace.py: 186: img = Image.new("RGB", (4, 4))
PIL.Image:2896 new
0:00:00 Image.py: 2912: _check_size(size)
PIL.Image:2875 _check_size
0:00:00 Image.py: 2883: if not isinstance(size, (list, tuple)):
0:00:00 Image.py: 2886: if len(size) != 2:
0:00:00 Image.py: 2889: if size[0] < 0 or size[1] < 0:
as you will see in the example I set:
packages_to_include=["PIL"],
so it'll trace PIL plus anything that is not under site-packages. If you need to trace another package, just add it to that list.
This is a very fresh work-in-progress package, so it's evolving as we are trying to make it help us resolve a very complex crashing situation.
When the per-node-rank trace files have been generated the following might be helpful to quickly analyse the situation:
- grep for a specific match and also print the file and line number where it was found:
grep -n "backward" trace*
- show the last few lines of each trace file preceded by its name:
find . -name "trace*" -exec sh -c 'echo "$1: $(tail -3 "$1")"' _ {} \;
- or similar to the above, but print the 5 last lines with the leading filename and some vertical whitespace for easier reading:
find . -name "trace*" -exec sh -c 'echo; echo $1; echo "$(tail -5 "$1")"' _ {} \;
- count how many times grep matched a given pattern in each file and print the matched file (in this example matching the pattern backward):
find . -name "trace*" -exec sh -c 'echo "$1: $(grep "backward" $1 | wc -l)"' _ {} \;
Now, once you have discovered where the hanging happens, to further understand why it is happening you would ideally use a debugger, but more often than not debugging multi-process (multi-node) issues can be very difficult.
In such situations a good old print works. You just need to add some debug prints before the calls where things hang, prints that would help understand what led to the deadlock. For example, some barrier was missing and one or a few processes skipped some code while the rest of the processes are still blocked waiting for everybody to send some data (for example in NCCL collective functions like gather or reduce).
You, of course, want to prefix each print with the rank of the process so that you can tell which is which. For example:
import torch.distributed as dist
print(f"{dist.get_rank()}: passed stage 0")
What you will quickly discover is that if you have multiple GPUs these prints will be badly interleaved and you will have a hard time making sense of the debug data. So let's fix this. We are going to override print with a custom version of the same, which uses flock to ensure that only one process can write to stdout at the same time.
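A minimal sketch of what such a flock-based wrapper could look like - the actual printflock.py bundled with this doc may differ in details:

import builtins
import fcntl

def printflock(*args, **kwargs):
    """a print that holds an exclusive flock, so lines from concurrent processes don't interleave"""
    with open(__file__, "r") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            builtins.print(*args, **kwargs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)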
The helper module printflock.py is included here. To activate it just run this at the top of the module you're debugging:
from printflock import printflock as print
and now all your print calls in that module will magically be non-interleaved. You can, of course, just use printflock directly:
from printflock import printflock
import torch.distributed as dist
printflock(f"{dist.get_rank()}: passed stage 0")
Code loops can be tricky to debug in hanging scenarios. If you have code like the following:
for i, d in enumerate(data):
    some_hanging_call(d)
it's possible that one process hangs in the first iteration, and another process in the second iteration, which makes things very confusing. But the stack trace won't give such indication, as the line numbers would be the same, even though the processes aren't in the same place code progression-wise.
In such situations unroll the loop to be:
d_iter = iter(data)
some_hanging_call(next(d_iter))
some_hanging_call(next(d_iter))
and now when you run py-spy the line numbers will be correct. The processes hanging in the first iteration will report the first some_hanging_call and those in the second iteration the second call - as each now has its own line.
Some AMD users may need to Disable IOMMU