multi_queries_paged_attention_kernel fails with Llama3 70B on a TPU-v4-16 with sequence length of 256 #8515

Open
OhadRubin opened this issue Dec 21, 2024 · 2 comments

Comments


OhadRubin commented Dec 21, 2024

I'm running Llama3 70B with vLLM on a TPU-v4-16. With the flash attention kernel I can go up to a sequence length of 16k, but with multi_queries_paged_attention it fails at a sequence length of 256. It seems that the page table is taking too much SMEM.
@vanbasten23 @WoosukKwon Any idea how to address this? (I'm familiar with Pallas programming.)
Maybe something along the lines of this? https://github.com/vllm-project/vllm/blob/02222a0256f60319f5bcd56d1d036a943d6334f8/vllm/attention/backends/pallas.py#L260
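For readers following along, here is a minimal sketch of one possible workaround of that kind, assuming the idea is to split the batch so that each kernel call only sees a block table that fits a fixed SMEM budget. The function name `split_batch_for_smem`, the budget value, and the int32 index size are hypothetical and are not taken from vLLM or torch_xla.

```python
def split_batch_for_smem(batch_size, pages_per_seq, smem_budget_bytes,
                         bytes_per_index=4):
    """Return (start, end) batch ranges whose block tables fit the budget.

    Assumes the block table has shape [batch, pages_per_seq] and that each
    entry takes `bytes_per_index` bytes (int32 by assumption).
    """
    size_per_seq = pages_per_seq * bytes_per_index
    max_seqs = smem_budget_bytes // size_per_seq
    if max_seqs == 0:
        # A single sequence's row already exceeds the budget, so splitting
        # the batch dimension alone cannot help.
        raise ValueError("one sequence's page indices exceed the SMEM budget")
    return [(start, min(start + max_seqs, batch_size))
            for start in range(0, batch_size, max_seqs)]


if __name__ == "__main__":
    # Hypothetical numbers: 10 sequences, 1024 pages each, 16 KiB budget.
    for start, end in split_batch_for_smem(10, 1024, 16 * 1024):
        print(f"run kernel on sequences [{start}, {end})")
```

Note the failure case: when a single row of the block table is larger than the budget, batch splitting is not enough and the kernel itself would have to change.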

Loading safetensors checkpoint shards: 100% Completed | 30/30 [02:03<00:00,  4.13s/it]                                                                                                                                                                                                            
INFO 12-21 14:11:07 ray_tpu_executor.py:276] # TPU blocks: 19032, # CPU blocks: 6552                                                                                                                                                                                                    
INFO 12-21 14:11:07 tpu_model_runner.py:274] Compiling the model with different input shapes...                                                                                                                                                                                         
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:11:08 tpu_model_runner.py:274] Compiling the model with different input shapes...                                                                                                                                             
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 tpu.py:27] Cannot use _Backend.FLASH_ATTN backend on TPU. [repeated 6x across cluster]                                                                                                                                                  
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 selector.py:163] Using Pallas backend. [repeated 6x across cluster]                                                                                                                                                                     
(RayWorkerWrapper pid=1005) WARNING 12-21 14:07:13 tpu_worker.py:62] Starting to init distributed environment with config: ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=8, worker_use_ray=False, max_parallel_loading_workers=None, disable_custom_all_reduce=False, tokenizer_pool_config=None, ray_workers_use_nsight=False, placement_group=<ray.util.placement_group.PlacementGroup object at 0x7f05501350f0>, distributed_executor_backend='ray', worker_cls='vllm.worker.tpu_worker.TPUWorker', sd_worker_cls='auto', world_size=8, rank=3) [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 parallel_state.py:954] world_size=8 rank=3 local_rank=3 distributed_init_method=tcp://10.130.0.185:57577 backend=gloo [repeated 6x across cluster]                                                                                      
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 parallel_state.py:959] attempting to initialize distributed environment [repeated 6x across cluster]                                                                                                                                    
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_world_group: local_rank=3 [repeated 12x across cluster]                                                                                                                                                                               
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_world_group: backend='gloo' [repeated 6x across cluster]                                                                                                                                                                              
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_model_parallel_group bla bla: local_rank=3 [repeated 26x across cluster]                                                                                                                                                              
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_model_parallel_group bla bla: backend='gloo' [repeated 13x across cluster]                                                                                                                                                            
(RayWorkerWrapper pid=1005) self.cpu_group=<torch.distributed.distributed_c10d.ProcessGroup object at 0x7f051028d330> [repeated 6x across cluster]                                                                                                                                      
INFO 12-21 14:13:02 tpu_model_runner.py:284] batch_size: 1, seq_len: 16                                                                                                                                                                                                                 
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:02 tpu_model_runner.py:284] batch_size: 1, seq_len: 16                                                                                                                                                                     
(RayWorkerWrapper pid=895, ip=10.130.0.186) INFO 12-21 14:11:08 tpu_model_runner.py:274] Compiling the model with different input shapes... [repeated 6x across cluster]                                                                                                                
INFO 12-21 14:13:05 tpu_model_runner.py:284] batch_size: 1, seq_len: 32                                                                                                                                                                                                                 
INFO 12-21 14:13:07 tpu_model_runner.py:284] batch_size: 1, seq_len: 64                                                                                                                                                                                                                 
(RayWorkerWrapper pid=995) INFO 12-21 14:13:07 tpu_model_runner.py:284] batch_size: 1, seq_len: 64 [repeated 18x across cluster]                                                                                                                                                        
INFO 12-21 14:13:10 tpu_model_runner.py:284] batch_size: 1, seq_len: 128                                                                                                                                                                                                                
INFO 12-21 14:13:12 tpu_model_runner.py:284] batch_size: 1, seq_len: 256                                                                                                                                                                                                                
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:12 tpu_model_runner.py:284] batch_size: 1, seq_len: 256 [repeated 10x across cluster]                                                                                                                                      
INFO 12-21 14:13:15 tpu_model_runner.py:284] batch_size: 1, seq_len: 512                                                                                                                                                                                                                
INFO 12-21 14:13:19 tpu_model_runner.py:284] batch_size: 1, seq_len: 1024                                                                                                                                                                                                               
(RayWorkerWrapper pid=995) INFO 12-21 14:13:19 tpu_model_runner.py:284] batch_size: 1, seq_len: 1024 [repeated 14x across cluster]                                                                                                                                                      
INFO 12-21 14:13:22 tpu_model_runner.py:284] batch_size: 1, seq_len: 2048                                                                                                                                                                                                               
INFO 12-21 14:13:27 tpu_model_runner.py:284] batch_size: 1, seq_len: 4096                                                                                                                                                                                                               
(RayWorkerWrapper pid=995) INFO 12-21 14:13:27 tpu_model_runner.py:284] batch_size: 1, seq_len: 4096 [repeated 14x across cluster]                                                                                                                                                      
INFO 12-21 14:13:32 tpu_model_runner.py:284] batch_size: 1, seq_len: 8192                                                                                                                                                                                                               
(RayWorkerWrapper pid=995) INFO 12-21 14:13:32 tpu_model_runner.py:284] batch_size: 1, seq_len: 8192 [repeated 7x across cluster]                                                                                                                                                       
INFO 12-21 14:13:38 tpu_model_runner.py:284] batch_size: 1, seq_len: 16384                                                                                                                                                                                                              
INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 150.46 s.                                                                                                                                                                                                  
INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill...                                                                                                                                                                      
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 149.52 s.                                                                                                                                                      
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill...                                                                                                                          
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:284] batch_size: 1, seq_len: 16384 [repeated 7x across cluster]                                                                                                                                     
INFO 12-21 14:15:53 tpu_model_runner.py:306] batch_size: 1, seq_len: 16
(RayWorkerWrapper pid=1005) INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 149.50 s. [repeated 6x across cluster]                                                                                                                                         
(RayWorkerWrapper pid=1005) INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill... [repeated 6x across cluster]                                                                                                             
(RayWorkerWrapper pid=995) INFO 12-21 14:15:53 tpu_model_runner.py:306] batch_size: 1, seq_len: 16 [repeated 7x across cluster]                                                                                                                                                         
INFO 12-21 14:16:31 tpu_model_runner.py:306] batch_size: 1, seq_len: 32                                                                                                                                                                                                                 
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:16:31 tpu_model_runner.py:306] batch_size: 1, seq_len: 32 [repeated 7x across cluster]                                                                                                                                        
INFO 12-21 14:17:07 tpu_model_runner.py:306] batch_size: 1, seq_len: 64                                                                                                                                                                                                                 
(RayWorkerWrapper pid=995) INFO 12-21 14:17:07 tpu_model_runner.py:306] batch_size: 1, seq_len: 64 [repeated 7x across cluster]                                                                                                                                                         
INFO 12-21 14:17:48 tpu_model_runner.py:306] batch_size: 1, seq_len: 128                                                                                                                                                                                                                
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:17:48 tpu_model_runner.py:306] batch_size: 1, seq_len: 128 [repeated 7x across cluster]                                                                                                                                       
INFO 12-21 14:18:30 tpu_model_runner.py:306] batch_size: 1, seq_len: 256                                                                                                                                                                                                                
(RayWorkerWrapper pid=895, ip=10.130.0.186) INFO 12-21 14:18:30 tpu_model_runner.py:306] batch_size: 1, seq_len: 256 [repeated 7x across cluster]    

radna0 commented Dec 25, 2024

@OhadRubin I'm not familiar with multi_queries_paged_attention, but the custom flash_attention implementation in torch_xla has default block sizes set. You could try adjusting those settings to use more TPU memory.

Collaborator

vanbasten23 commented Jan 7, 2025

@OhadRubin The permanent fix should be in the kernel. The current paged attention kernel preloads all page_indices into SMEM. When seq_len is long, page_indices becomes large, and since SMEM is very limited, you are likely to run into an SMEM OOM error. One idea is to load page_indices on demand instead of loading everything at once. The issue exists in both paged attention kernels (single-query and multi-query).
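To make the trade-off concrete, here is a rough plain-Python sketch (not Pallas, and not the actual kernel) that contrasts the footprint of preloading every page index with that of keeping only a few fixed-size chunks resident. The helper names, the int32 index size, the page size, and the double-buffering factor are all assumptions for illustration.

```python
import math

BYTES_PER_INDEX = 4  # assumption: page indices are stored as int32


def preload_all_bytes(batch_size, max_seq_len, page_size):
    """Footprint when every page index of every sequence is resident at once."""
    pages_per_seq = math.ceil(max_seq_len / page_size)
    return batch_size * pages_per_seq * BYTES_PER_INDEX


def on_demand_bytes(pages_per_chunk, num_buffers=2):
    """Footprint when only a few fixed-size chunks of indices are resident."""
    return pages_per_chunk * num_buffers * BYTES_PER_INDEX


def iter_page_chunks(page_indices, pages_per_chunk):
    """Yield the page indices of one sequence chunk by chunk, in order."""
    for start in range(0, len(page_indices), pages_per_chunk):
        yield page_indices[start:start + pages_per_chunk]


if __name__ == "__main__":
    # Hypothetical shapes: 8 sequences, 16384 tokens max, 16 tokens per page.
    print(preload_all_bytes(batch_size=8, max_seq_len=16384, page_size=16))
    print(on_demand_bytes(pages_per_chunk=8))
    for chunk in iter_page_chunks(list(range(10)), pages_per_chunk=4):
        print(chunk)
```

In this model the preloaded footprint grows with both batch size and maximum sequence length, while the on-demand footprint depends only on the chunk size and the number of buffers.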
