@OhadRubin I'm not familiar with multi_queries_paged_attention, but the custom flash_attention implementation in torch_xla has default block sizes set. You could try adjusting those settings to make better use of TPU memory.
@OhadRubin The permanent fix should be in the kernel. The current paged attention kernel preloads all page_indices into SMEM. When seq_len is long, page_indices becomes large, and since SMEM is very limited, you'd likely run into an SMEM OOM error (see the sizing sketch below). So one idea is to load page_indices on demand rather than loading everything at once. The issue exists in both paged attention kernels (single-query and multi-queries).
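To get a feel for when the preloaded page_indices stop fitting, here is a rough sizing sketch. The int32 index width and the ~1 MB SMEM budget are assumptions for illustration, not figures taken from the kernel:

```python
def page_indices_smem_bytes(batch_size, seq_len, page_size,
                            index_bytes=4, smem_budget_bytes=1 << 20):
    """Estimate the SMEM taken by a fully preloaded page_indices array.

    Assumes int32 page indices and a ~1 MB SMEM budget; the real budget
    depends on the TPU generation and on what else the kernel keeps in SMEM.
    """
    pages_per_seq = -(-seq_len // page_size)  # ceil(seq_len / page_size)
    total_bytes = batch_size * pages_per_seq * index_bytes
    return total_bytes, total_bytes <= smem_budget_bytes

# e.g. 256 sequences of 32k tokens with 16-token pages:
# 256 * 2048 * 4 bytes = 2 MiB, past the assumed ~1 MB budget.
print(page_indices_smem_bytes(batch_size=256, seq_len=32768, page_size=16))
```

Loading page_indices on demand per compute block would keep only a small, fixed-size slice of this array in SMEM regardless of seq_len.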
I'm running Llama3 70B with vLLM on a TPU v4-16. With the flash attention kernel I'm able to go up to a 16k sequence length, but with multi_queries_paged_attention, even at sequence length 256, it seems the page table is taking too much SMEM.
@vanbasten23 @WoosukKwon any idea how to address this? (I'm familiar with Pallas programming.)
maybe something along the lines of this? https://github.com/vllm-project/vllm/blob/02222a0256f60319f5bcd56d1d036a943d6334f8/vllm/attention/backends/pallas.py#L260
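For reference, a caller-side workaround in that spirit could split the batch and hand the kernel a smaller page-table slab per launch. This is only a sketch; `paged_attention_fn` and its signature are placeholders, not the actual torch_xla/vLLM API. Since sequences are independent, the chunk outputs can simply be concatenated:

```python
import torch

def chunked_paged_attention(paged_attention_fn, q, k_cache, v_cache,
                            page_indices, chunk_size=8):
    # `paged_attention_fn` is a stand-in for the real kernel call; its exact
    # signature is an assumption made for this sketch.
    # page_indices has shape [batch_size, pages_per_sequence]; chunking the
    # batch shrinks the slab of indices each kernel launch has to hold.
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        end = start + chunk_size
        outputs.append(
            paged_attention_fn(q[start:end], k_cache, v_cache,
                               page_indices[start:end]))
    return torch.cat(outputs, dim=0)
```

That only postpones the problem for very long sequences, though; the on-demand loading of page_indices inside the kernel would be the more permanent fix.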