
How to do multi-machine SPMD/FSDPv2 training with TPU? #8492

Open
radna0 opened this issue Dec 13, 2024 · 6 comments

radna0 commented Dec 13, 2024

❓ Questions and Help

I saw #6362, but I couldn't find an example training script. For example, if I have multiple TPU v3-8 VMs, how would I achieve this with SPMD/FSDPv2?

I'm currently sending the commands to all TPU VMs this way:

python3.10 podrun --include-local -- hostname
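
For context, the single-host SPMD/FSDPv2 setup I'm trying to scale out across hosts looks roughly like this (a rough sketch following the pytorch/xla SPMD and FSDPv2 user guides; the mesh shape and the toy model are illustrative only):

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()  # enable SPMD execution mode

# 2-D (fsdp, model) mesh over all devices; FSDPv2 shards parameters
# along the 'fsdp' axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('fsdp', 'model'))
xs.set_global_mesh(mesh)

device = xm.xla_device()
model = FSDPv2(nn.Linear(1024, 1024).to(device))  # toy model, illustrative only
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

inputs = torch.randn(8, 1024).to(device)
xs.mark_sharding(inputs, mesh, ('fsdp', None))  # shard the batch dimension

loss = model(inputs).sum()
loss.backward()
optimizer.step()
xm.mark_step()  # materialize the step on the TPU
```
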
radna0 changed the title from "How to do multi-machine spmd training with TPU?" to "How to do multi-machine SPMD/FSDPv2 training with TPU?" on Dec 13, 2024

radna0 (Author) commented Dec 14, 2024

Update: I found https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_distributed_checkpoint.md#process-groups. Is there an example script for this?
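
For reference, my reading of that section is that the setup looks roughly like this (only a sketch; the checkpoint path and the saved state_dict are illustrative, and `model` is assumed to come from the training script):

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # per the doc, registers the 'xla://' init_method
import torch_xla.experimental.distributed_checkpoint as xc

xr.use_spmd()

# Under SPMD there is one process per host, so a CPU-backed (gloo) process
# group is created just to coordinate the distributed checkpoint.
dist.init_process_group(backend='gloo', init_method='xla://')

state_dict = {'model': model.state_dict()}  # `model` assumed from the training script

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter('/tmp/ckpt'),  # illustrative path
    planner=xc.SPMDSavePlanner(),
)
```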

radna0 (Author) commented Dec 20, 2024

Can anybody help? I'm still stuck on this.

qihqi (Collaborator) commented Jan 4, 2025

Hi,

So what you want to do is run the same script on all the hosts at the same time.
Your solution with podrun is a fine way to accomplish this.

Other alternatives include:

  1. Use the gcloud command to broadcast the same command to all the TPU workers:
gcloud compute tpus tpu-vm ssh --zone "$ZONE" "$TPU_NAME" --project "$PROJECT" --worker=all --command="python script.py"

For example: https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/Llama3-70B-PyTorch/GCE/README.md

  2. Use xpk (https://github.com/AI-Hypercomputer/xpk). This tool automates some of this with GKE + Docker: you give it a Docker image and it will start the container on all the machines in your cluster. Example: https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/Llama3-70B-PyTorch/XPK/README.md

Feel free to explore other examples in the https://github.com/AI-Hypercomputer/tpu-recipes/tree/main repository; we publish ready-to-run examples there.

Thanks!

radna0 (Author) commented Jan 4, 2025

Thank you @qihqi for suggesting the different options.

It seems to me that adding XLA as a backend to DeepSpeed would still be the best path going forward. I'm already working on this and have it somewhat running, potentially using pipeline parallelism and other utilities offered by DeepSpeed.

Coming back to this, I'm using podrun because my VMs are mostly v3-8 VMs in different regions, so gcloud or xpk don't seem to be an option here, as they only work on TPU pods within a single region. I could set up a separate cluster in each region, but in that case, with DeepSpeed running on XLA, I can achieve the same or even better results with more utilities.

This is why I opened a request on both the PyTorch/XLA and DeepSpeed repositories. Also, most of the examples here use v4 VMs and above; I can hardly get my hands on those, so most of these examples aren't accessible to me. I can see why most of the support targets those VMs, but ordinary folks like me can only get v3s or v2s.

qihqi (Collaborator) commented Jan 6, 2025

> because my VMs are mostly v3-8 VMs in different regions

One issue with this setup, versus allocating a v3-16 and getting 2 IP addresses (2 hosts), is that the 2 hosts of a v3-16 actually have their devices connected with ICI links. The 2 VMs you allocate separately, especially in different regions, don't have that, and the communication between those 2 hosts will go over the slow data center network (DCN). It is thus not recommended.
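
For reference, when a deployment is intentionally meant to span slices connected over DCN (e.g. Multislice), SPMD expresses the ICI vs. DCN split with a hybrid mesh, along these lines (a sketch based on the SPMD user guide; the shapes assume 2 slices of a v4-8 and are illustrative only):

```python
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Inner (ICI) axes parallelize within a slice; the outer (DCN) axis only
# carries data parallelism across slices, keeping heavy collectives on ICI.
mesh = xs.HybridMesh(
    ici_mesh_shape=(1, 4, 1),   # (data, fsdp, tensor) within one v4-8 slice
    dcn_mesh_shape=(2, 1, 1),   # data parallelism across the 2 slices
    axis_names=('data', 'fsdp', 'tensor'),
)
xs.set_global_mesh(mesh)
```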

I see the issue for adding XLA to DeepSpeed. I am happy to see you are making progress on that. It seems like it can be accomplished with changes to DeepSpeed only, without changing pytorch/xla itself.

radna0 (Author) commented Jan 7, 2025

> One issue with this setup, versus allocating a v3-16 and getting 2 IP addresses (2 hosts), is that the 2 hosts of a v3-16 actually have their devices connected with ICI links. The 2 VMs you allocate separately, especially in different regions, don't have that, and the communication between those 2 hosts will go over the slow data center network (DCN). It is thus not recommended.

Yeah, I think that's part of why I'm only planning to do SPMD or pipeline parallelism on one VM, or on VMs in the same region. But the main point is that such a setup isn't even possible right now.

> I see the issue for adding XLA to DeepSpeed. I am happy to see you are making progress on that. It seems like it can be accomplished with changes to DeepSpeed only, without changing pytorch/xla itself.

Yes, I think only minor tweaks are needed to add XLA to DeepSpeed. I believe I can also implement XLA-specific ops in DeepSpeed based on JAX if needed.
