
How to run XLA with CPU offloaded models #8049

Open

radna0 opened this issue Sep 23, 2024 · 12 comments

@radna0

radna0 commented Sep 23, 2024

❓ Questions and Help

How do you run models that are offloaded to the CPU? I am trying to work with enable_sequential_cpu_offload or enable_model_cpu_offload, but when running torch_xla.sync()/xm.mark_step(), the graph does not seem to account for the offloading and in turn takes much more memory than running the model on CPU alone. For example, the same workload reportedly peaks at around 25GB on the CPU but requires 170GB on XLA devices; this was tested with the EasyAnimate V4 model generating a 960x1680 24fps video. If needed, I can provide code if this has not been implemented.

RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Compilation failure: Aborting compilation early because it's unlikely to have enough device memory. Requires 170.73G, has 14.71G available. If more detailed logging is desired, set --xla_tpu_impure_oom_fast_exit_threshold=-1
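
For reference, a minimal sketch of the pattern I'm attempting (the checkpoint path is a placeholder, and whether the diffusers offload hooks accept an XLA device depends on the diffusers/accelerate version):

import torch_xla.core.xla_model as xm
from diffusers import DiffusionPipeline

device = xm.xla_device()
pipe = DiffusionPipeline.from_pretrained("path/to/EasyAnimate")  # placeholder checkpoint
pipe.enable_model_cpu_offload(device=device)  # offload hooks were designed around eager (CUDA) devices

out = pipe(prompt="a prompt")  # traced lazily on XLA
xm.mark_step()                 # the traced HLO still references every weight on the device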

@miladm
Collaborator

miladm commented Sep 23, 2024

Thanks for your question @radna0!
This is a feature we'd like to onboard for PyTorch/XLA users. Can you share the use case you are running into as code / a repro?

We would certainly appreciate a code contribution on this topic.

cc @JackCaoG to add other potential context

@JackCaoG
Collaborator

All of the existing PyTorch ways of offloading things to CPU probably won't work on XLA. The reason is that most of those solutions move a tensor to CPU during the forward pass and bring it back during the backward pass. In the PyTorch/XLA case, however, running the Python modeling code only traces the graph; nothing actually executes at that point.

In order to support real CPU offloading, we would need to annotate the tensors we want offloaded to CPU in the HLO graph we produce, and implement the necessary runtime logic to do the loading and offloading. This feature is on our roadmap, but we don't have a timeline yet.
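
To illustrate the tracing point: under PyTorch/XLA's lazy execution, a .cpu() call inside a forward hook doesn't shrink the traced graph, it just forces execution. A small illustrative example:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4096, 4096, device=device)
y = x @ x                  # only traced; nothing has executed on the device yet
y_cpu = y.cpu()            # forces the pending graph to compile and run right now
y_back = y_cpu.to(device)  # re-uploads the data as a fresh device tensor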

@radna0
Author

radna0 commented Sep 24, 2024

Hi @miladm @JackCaoG, appreciate it!

Here's the code repo to clone/reproduce from

git clone -b TPU https://github.com/radna0/EasyAnimate.git

You can follow this guide if you wish to use Docker.

# pull image
docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate

# enter image
docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate

# clone code
git clone -b TPU https://github.com/radna0/EasyAnimate.git

# enter EasyAnimate's dir
cd EasyAnimate

# download weights
mkdir -p models/Diffusion_Transformer
mkdir -p models/Motion_Module
mkdir -p models/Personalized_Model

wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Diffusion_Transformer/EasyAnimateV4-XL-2-InP.tar.gz -O models/Diffusion_Transformer/EasyAnimateV4-XL-2-InP.tar.gz

cd models/Diffusion_Transformer/
tar -zxvf EasyAnimateV4-XL-2-InP.tar.gz
cd ../../

Or, if you use your own environment, I believe you just need torch, torch_xla, and torchvision, which I installed like so:

# Pytorch XLA
sudo pip uninstall torch torch_xla torchvision -y
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

You can then review the video generation settings in predict_t2v.py and run it via:

python3.10 predict_t2v.py

It should return an error like the following:

RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Compilation failure: Aborting compilation early because it's unlikely to have enough device memory. Requires 170.73G, has 14.71G available. If more detailed logging is desired, set --xla_tpu_impure_oom_fast_exit_threshold=-1

You can check the pipeline used at:

easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py

@JackCaoG
Collaborator

I guess my suggestion is not to use the other CPU offload methods, because they are unlikely to work (for the reason I gave above). We could spend time figuring out why this particular one doesn't work, but I don't think that's the best use of our time.

@radna0
Author

radna0 commented Sep 25, 2024

Hmm, that's true. Do you have any workarounds in mind that could help in the meantime? From my understanding, implementing CPU offloading at the HLO graph level would require a lot more work and time from your team, but having some form of CPU offloading could significantly benefit anyone using TPUs for memory-intensive models like this.

@JackCaoG
Collaborator

I can't really think of anything that would be useful in a real-world case. You could try to partition the graph into smaller ones, move the results back to CPU, and then move them to the device when needed (a rough sketch follows after this list), but that will likely not work very well because:

  1. a more fragmented graph usually means slower compute and less opportunity to fuse ops (hence higher peak memory usage);
  2. moving an XLA tensor to CPU forces the main process to block until device execution finishes, which prevents us from overlapping tracing and device execution.
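
For completeness, a rough sketch of that partitioning workaround (purely illustrative; the stage split is hypothetical), which runs into both problems above:

import torch_xla.core.xla_model as xm

def run_partitioned(stages, x, device):
    # `stages` is a hypothetical list of nn.Modules forming the pipeline.
    for stage in stages:
        stage.to(device)
        x = stage(x.to(device))
        xm.mark_step()      # cut the graph after each stage (problem 1: fragmentation)
        x = x.cpu()         # blocks until device execution finishes (problem 2)
        stage.to("cpu")     # free device memory before the next stage
    return x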

@radna0
Author

radna0 commented Sep 27, 2024

Thanks for the detailed explanation, @JackCaoG.

Given the limitations of partitioning graphs and the potential performance hit, do you think it would be feasible for your team to implement something where the model is moved to the XLA device only during the forward pass, without incorporating the transfer into the HLO graph?

For example, in my use case with denoising models, the transformer is only moved to the device during the denoising process. It's solely used to iteratively compute forward passes on the latents to produce the noise, so there’s no need to involve it in backward computations or gradient tracking. It would be extremely helpful if this kind of selective offloading could be supported, as it would allow for significantly reduced memory usage while keeping most of the workload on the XLA devices.

What do you think of this approach? It could offer a middle ground between full CPU offloading and managing memory without major overhead.
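
Roughly what I have in mind, as a sketch (names are illustrative, not a working implementation):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

@torch.no_grad()                      # inference only; no gradient tracking needed
def denoise(transformer, latents, timesteps):
    transformer.to(device)            # weights live on the device only for this loop
    latents = latents.to(device)
    for t in timesteps:
        noise_pred = transformer(latents, t)
        latents = latents - noise_pred    # placeholder for the real scheduler step
        xm.mark_step()                    # materialize each iteration's graph
    transformer.to("cpu")             # release device memory afterwards
    return latents.cpu()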

@JackCaoG
Collaborator

I think what you described is a reasonable first milestone to target when we start the design and implementation of CPU offloading, but due to resourcing I don't have a timeline for this ask.

@radna0
Author

radna0 commented Sep 27, 2024

Do you not have the resources to work with? I can certainly provide some of the TPU pods that I am not using at the moment. @JackCaoG

@JackCaoG
Collaborator

Oh, by resources I meant that no one on the team currently has the cycles to work on this immediately.

@radna0
Author

radna0 commented Oct 1, 2024

How long would it take for someone on your team to implement this? I really need this to work in order to move forward, and I am currently blocked by this issue. Is there anything I or your team can do to make it work in the meantime?

@JackCaoG
Collaborator

JackCaoG commented Oct 2, 2024

I don't think we have anyone to work on this feature this year. Designing and implementing it could take one to two quarters, but I can't be sure until I look at the design.
