How to do multi-machine SPMD/FSDPv2 training with TPU? #8492
Comments
Update: I found https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_distributed_checkpoint.md#process-groups
Can anybody help? I'm still stuck on this.
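For anyone landing here: that doc's point is that in SPMD mode the `xla` process-group backend is not used for `torch.distributed` APIs such as distributed checkpointing; a CPU backend like `gloo` is initialized instead, with the `xla://` init method handling rank and world-size discovery on TPUs. A minimal sketch along those lines (module paths may differ slightly between torch_xla releases):

```python
import torch.distributed as dist
import torch_xla.runtime as xr
# Importing this module registers the `xla://` init_method.
import torch_xla.distributed.xla_backend  # noqa: F401

xr.use_spmd()

# In SPMD mode the compiler handles collectives, so we back torch.distributed
# with a CPU (gloo) process group; `xla://` discovers the master IP, rank,
# and global world size on TPU hosts.
dist.init_process_group('gloo', init_method='xla://')
```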
Hi, so what you want to do is run the same script on all the hosts at the same time (see the sketch after the links below). Other alternatives include:
For example: https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/Llama3-70B-PyTorch/GCE/README.md
Feel free to explore other examples in the https://github.com/AI-Hypercomputer/tpu-recipes/tree/main repository; we publish ready-to-run examples there. Thanks!
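To make the "same script on every host" point concrete, here is a minimal FSDPv2/SPMD training sketch, not a definitive recipe: the toy model, batch size, TPU name, zone, and script filename in the comments are placeholders, and exact module paths can vary between torch_xla releases.

```python
# Run this same file on every host, e.g. (placeholder names):
#   gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central1-a --worker=all \
#     --command="python3 train_fsdpv2.py"
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()

# One mesh over every TPU core across all hosts; SPMD treats them as one pool.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)

device = xm.xla_device()
model = FSDPv2(nn.Linear(1024, 1024).to(device))  # toy model as a placeholder
optimizer = optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    optimizer.zero_grad()
    x = torch.randn(32, 1024, device=device)
    # Shard the batch dimension across the `fsdp` mesh axis.
    xs.mark_sharding(x, mesh, ('fsdp', None))
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()  # materialize the step on the XLA device
```

The key property is that every host runs the identical script; the SPMD runtime and compiler handle placement and collectives across the whole mesh, so there is no per-host rank-specific logic in user code.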
Thank you @qihqi for suggesting different options. It seems to me that adding XLA as a backend to DeepSpeed would still be the best path going forward. I am already working on this and have it somewhat running, potentially using pipeline parallelism and other utilities offered by DeepSpeed. Coming back to this, I'm using [...]; this is why I opened up a request on both the PyTorch XLA and DeepSpeed repositories. Also, most of the examples here are for v4 VMs and above, which I can hardly get my hands on, so most of them are not accessible to me. I can see why most of the support targets those VMs, but common folks like me can only get our hands on v3s or v2s.
I see the issue for adding XLA to DeepSpeed. I am happy to see you are making progress on that. It seems like it can be accomplished with changes to DeepSpeed only, without changing pytorch/xla itself.
Yeah, I think that's part of why I'm only planning to do SPMD or pipeline parallelism on one VM, or on VMs in the same region. But the main thing is that even that setup isn't possible right now.
Yes, I think what's already there is enough, or only minor tweaks are needed, to add XLA to DeepSpeed. I believe I can also implement XLA-specific ops in DeepSpeed based on JAX if needed.
❓ Questions and Help
I saw #6362, but there's no example training script to be found. For example, if I have multiple TPU v3-8 VMs, how would I achieve this with SPMD/FSDPv2?
I'm currently sending the commands to all TPU VMs this way:
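For context, a common way to broadcast the same command to every worker of a multi-host TPU VM slice is gcloud's `--worker=all` flag; the TPU name, zone, and script path below are placeholders, and for a set of independent v3-8 VMs you would loop over the VM names instead:

```bash
gcloud compute tpus tpu-vm ssh my-tpu \
  --zone=us-central1-a \
  --worker=all \
  --command="python3 train_fsdpv2.py"
```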