diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh new file mode 100644 index 0000000000000..12056ec2907f6 --- /dev/null +++ b/.buildkite/run-cpu-test.sh @@ -0,0 +1,14 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.cpu . + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-check cpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 4dde733581822..3ed23c62c005d 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -8,6 +8,9 @@ steps: queue: amd command: bash .buildkite/run-amd-test.sh + - label: "CPU Test" + command: bash .buildkite/run-cpu-test.sh + - label: ":docker: build image" commands: - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." diff --git a/CMakeLists.txt b/CMakeLists.txt index 412b9c0cd59e0..9d90f4e7a0496 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21) project(vllm_extensions LANGUAGES CXX) +option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") + message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") +message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) @@ -76,6 +79,19 @@ find_package(Torch REQUIRED) find_library(torch_python_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") +# +# Forward the non-CUDA device extensions to external CMake scripts. +# +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + NOT VLLM_TARGET_DEVICE STREQUAL "rocm") + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() + message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + endif() + return() +endif() + # # Set up GPU language and check the torch version and warn if it isn't # what is expected. diff --git a/Dockerfile.cpu b/Dockerfile.cpu new file mode 100644 index 0000000000000..4251fddd6cc3b --- /dev/null +++ b/Dockerfile.cpu @@ -0,0 +1,20 @@ +# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. + +FROM ubuntu:22.04 + +RUN apt-get update -y \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +RUN pip install --upgrade pip \ + && pip install wheel packaging ninja setuptools>=49.4.0 numpy + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install + +CMD ["/bin/bash"] diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake new file mode 100644 index 0000000000000..0cf37769a6960 --- /dev/null +++ b/cmake/cpu_extension.cmake @@ -0,0 +1,90 @@ +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# +# Define environment variables for special configurations +# +if(DEFINED ENV{VLLM_CPU_AVX512BF16}) + set(ENABLE_AVX512BF16 ON) +endif() + +include_directories("${CMAKE_SOURCE_DIR}/csrc") + +# +# Check the compile flags +# +list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") + +execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + +if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +endif() + +function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + +if (AVX512_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mavx512f" + "-mavx512vl" + "-mavx512bw" + "-mavx512dq") + + find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) + if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + else() + message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") + endif() +else() + message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + + +# +# Define extension targets +# + +# +# _C extension +# +set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/attention.cpp" + "csrc/cpu/cache.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/cpu/pybind.cpp") + +define_gpu_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + WITH_SOABI +) + +add_custom_target(default) +message(STATUS "Enabling C extension.") +add_dependencies(default _C) + diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp new file mode 100644 index 0000000000000..1bd24eb79d129 --- /dev/null +++ b/csrc/cpu/activation.cpp @@ -0,0 +1,148 @@ +#include "cpu_types.hpp" + +namespace { +template +void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, + scalar_t *__restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + int start = i * d; + if constexpr (is_gated) { + start *= 2; + } + + const scalar_vec_t x(input + start + j); + const vec_op::FP32Vec8 f32_x(x); + vec_op::FP32Vec8 f32_ans = func(f32_x); + + if constexpr (is_gated) { + const scalar_vec_t y(input + start + d + j); + const vec_op::FP32Vec8 f32_y(y); + f32_ans = f32_y * f32_ans; + } + + const scalar_vec_t result(f32_ans); + result.save(output + i * d + j); + } + } +} + +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 x3 = x * x * x; + const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); + const vec_op::FP32Vec8 w3(0.044715); + const vec_op::FP32Vec8 x_3 = x * x * x; + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); +} +}; // namespace + +void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); +} + +void gelu_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); +} + +void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_tanh_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl) + }); +} + +void gelu_new(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_new_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_new_impl) + }); +} + +void gelu_fast(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_fast_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); +} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp new file mode 100644 index 0000000000000..6f38e923d7d6f --- /dev/null +++ b/csrc/cpu/attention.cpp @@ -0,0 +1,744 @@ +#include "cpu_types.hpp" + +namespace { + +template struct KernelVecType { + using q_load_vec_type = void; + using q_vec_type = void; + using k_load_vec_type = void; + using k_vec_type = void; + using qk_acc_vec_type = void; + using v_load_vec_type = void; +}; + +template <> struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec4; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP32Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512BF16__ +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::BF16Vec32; + using k_vec_type = vec_op::BF16Vec32; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#else +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#endif + +template +FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, + const int capacity) { + T max = data[0]; + for (int i = 1; i < size; ++i) { + max = max >= data[i] ? max : data[i]; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE std::pair +reduceSoftmaxAlibi(T *data, const int size, const int capacity, + const float alibi_slope, const int start_index, + const int context_len) { + data[0] += alibi_slope * (start_index - context_len + 1); + T max = data[0]; + for (int i = 1; i < size; ++i) { + T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + data[i] = qk; + max = max >= qk ? max : qk; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, + const int size) { + T max = max_data[0]; + for (int i = 1; i < size; ++i) { + max = max >= max_data[i] ? max : max_data[i]; + } + + T rescaled_sum = 0; + for (int i = 0; i < size; ++i) { + T rescale_factor = std::exp(max_data[i] - max); + rescaled_sum += rescale_factor * sum_data[i]; + sum_data[i] *= rescale_factor; + } + for (int i = 0; i < size; ++i) { + sum_data[i] /= rescaled_sum + 1e-8; + } +} + +template +struct reduceQKBlockKernel { + using q_load_vec_type = typename KernelVecType::q_load_vec_type; + using q_vec_type = typename KernelVecType::q_vec_type; + using k_load_vec_type = typename KernelVecType::k_load_vec_type; + using k_vec_type = typename KernelVecType::k_vec_type; + using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; + + constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; + constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; + constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; + + static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); + static_assert(k_load_vec_type::get_elem_num() % x == 0); + static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); + + FORCE_INLINE static void call(const scalar_t *__restrict__ q, + const scalar_t *__restrict__ k_block, + float *__restrict__ logits, float scale, + const int token_num) { + const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; + + qk_acc_vec_type group_accums[MAX_GROUP_NUM]; + if (token_num == BLOCK_SIZE) { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + + vec_op::unroll_loop( + [k_block, &q_group_vec, &group_accums](int token_group_idx) { + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } else { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + for (int token_group_start = 0; token_group_start < group_num; + token_group_start += UNROLL_GROUP_NUM) { + vec_op::unroll_loop( + [token_group_start, k_block, &q_group_vec, + &group_accums](int token_group_idx) { + token_group_idx += token_group_start; + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } + } + + for (int token_group_idx = 0; token_group_idx < group_num; + ++token_group_idx) { + vec_op::unroll_loop( + [&group_accums, logits, scale, token_group_idx](int token_idx) { + float dot_v = + group_accums[token_group_idx] + .template reduce_sub_sum(token_idx); + logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = + dot_v * scale; + }); + } + } +}; + +template +FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, + acc_t &&acc) { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); + static_assert(BLOCK_SIZE == ELEM_NUM); + vec_op::FP32Vec16 prob_vec(prob); + + vec_op::unroll_loop([&](int head_elem_idx) { + v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); + vec_op::FP32Vec16 fp32_v_vec(v_vec); + acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; + }); +} +}; // namespace + +// Paged attention v1 +namespace { +template +struct paged_attention_v1_impl { + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + + int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + + const int parallel_work_item_num = omp_get_max_threads(); + + size_t logits_bytes = + parallel_work_item_num * max_context_len_padded * sizeof(float); + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_context_len_padded] + +#pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int context_len = context_lens[seq_idx]; + const int *seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const int last_block_token_num = + context_len - (block_num - 1) * BLOCK_SIZE; + float *__restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_context_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + // Compute softmax + if (alibi_slopes) { + reduceSoftmaxAlibi(thread_block_logits, context_len, + block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, + context_len); + } else { + reduceSoftmax(thread_block_logits, context_len, + block_num * BLOCK_SIZE); + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + std::free(logits); + } +}; + +#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + +template +void paged_attention_v1_impl_launcher( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + context_lens, max_context_len, alibi_slopes); + +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, + int num_kv_heads, float scale, + torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype) { + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); +} + +// Paged attention v2 +namespace { +template +struct paged_attention_v2_impl { + static void call( + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float + *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads, const int max_num_partitions) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); + static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); + +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int partition_idx = 0; partition_idx < max_num_partitions; + ++partition_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int start_token_idx = partition_idx * PARTITION_SIZE; + + if (start_token_idx >= context_len) + continue; + + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + const bool no_reduce = (partition_num == 1); + const int context_token_num = + (std::min(context_len, start_token_idx + PARTITION_SIZE) - + start_token_idx); + const int block_num = + (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int last_block_token_num = + context_token_num - (block_num - 1) * BLOCK_SIZE; + const int *seq_block_table = block_tables + + max_num_blocks_per_seq * seq_idx + + start_token_idx / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + std::pair max_and_sum; + if (alibi_slopes) { + max_and_sum = reduceSoftmaxAlibi( + logits, context_token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, context_len); + } else { + max_and_sum = reduceSoftmax(logits, context_token_num, + block_num * BLOCK_SIZE); + } + + auto &&[max_logit, exp_sum] = max_and_sum; + + scalar_t *__restrict__ output_buffer = nullptr; + if (!no_reduce) { + auto idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + max_logits[idx] = max_logit; + exp_sums[idx] = exp_sum; + output_buffer = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + output_buffer = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + output_buffer + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + } + + // Rescale partition softmax and store the factors to exp_sums +#pragma omp parallel for collapse(2) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + reducePartitonSoftmax( + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + partition_num); + } + } + + // Reduce values + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); + constexpr int head_elem_num_per_group = + 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE + // didn't align with 64 bytes + static_assert(HEAD_SIZE % head_elem_num_per_group == 0); + constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; + const float *__restrict__ rescale_factors = exp_sums; +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + const float *__restrict__ seq_head_rescale_factors = + rescale_factors + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + const scalar_t *__restrict__ seq_head_tmp_out = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + group_idx * head_elem_num_per_group; + scalar_t *__restrict__ seq_head_output = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + group_idx * head_elem_num_per_group; + + vec_op::FP32Vec16 acc; + for (int i = 0; i < partition_num; ++i) { + vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); + v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); + vec_op::FP32Vec16 fp32_value(value); + acc = acc + fp32_value * rescale_factor; + } + v_load_vec_type cast_acc(acc); + cast_acc.save(seq_head_output); + } + } + } + } +}; + +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ + max_num_partitions); + +template +void paged_attention_v2_impl_launcher( + torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, + torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + int max_num_partitions = exp_sums.size(-1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, block_size, \ + max_context_len, alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype) { + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); +} diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp new file mode 100644 index 0000000000000..94f5affc39f02 --- /dev/null +++ b/csrc/cpu/cache.cpp @@ -0,0 +1,139 @@ +#include +#include + +#include "cpu_types.hpp" + +namespace { +template +void copy_blocks_cpu_impl( + std::vector &key_caches, + std::vector &value_caches, + const std::vector> mapping_pairs, + const int element_num_per_block, const int layer_num) { + const size_t pair_num = mapping_pairs.size(); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; +#pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { + int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t target_offset = + element_num_per_block * mapping_pairs[pair].second; + scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t *source_ptr = key_cache_ptr + source_offset; + scalar_t *target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + + scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + } + } +} + +template +void reshape_and_cache_cpu_impl( + const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, + scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx >= 0) { + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; + const scalar_t *src_key_head_ptr = key + src_key_head_idx; + const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + scalar_t *target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + scalar_t *target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + + for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { + const int64_t target_offset = + src_key_idx * block_size + block_offset * x; + for (int i = 0; i < x; ++i) { + target_key_head_ptr[target_offset + i] = + src_key_head_ptr[src_key_idx + i]; + } + } + + for (int src_value_idx = 0; src_value_idx < head_size; + ++src_value_idx) { + const int64_t target_offset = + src_value_idx * block_size + block_offset; + target_value_head_ptr[target_offset] = + src_value_head_ptr[src_value_idx]; + } + } + } + } +} +}; // namespace + +void copy_blocks(std::vector &key_caches, + std::vector &value_caches, + const std::map> &block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + + std::vector> mapping_pairs; + mapping_pairs.reserve(block_mapping.size()); + for (const auto &pair : block_mapping) { + for (const auto &dst : pair.second) { + mapping_pairs.emplace_back(pair.first, dst); + } + } + + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); +} + +void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, + value_stride, num_heads, head_size, block_size, x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); +} + +void swap_blocks(torch::Tensor &src, torch::Tensor &dst, + const std::map &block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") +} diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp new file mode 100644 index 0000000000000..c1d3ec058b991 --- /dev/null +++ b/csrc/cpu/cpu_types.hpp @@ -0,0 +1,352 @@ + +#ifndef CPU_TYPES_HPP +#define CPU_TYPES_HPP + +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#endif + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp new file mode 100644 index 0000000000000..467f0dc84982c --- /dev/null +++ b/csrc/cpu/layernorm.cpp @@ -0,0 +1,117 @@ +#include "cpu_types.hpp" + +namespace { +template +void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + vec_op::FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_w(w); + + vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } +} + +template +void fused_add_rms_norm_impl(scalar_t *__restrict__ input, + scalar_t *__restrict__ residual, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto residual_p = residual + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t res(residual_p + j); + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_res(res); + + fp32_x = fp32_x + fp32_res; + variance = variance + fp32_x * fp32_x; + scalar_vec_t out(fp32_x); + out.save(residual_p + j); + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t w(weight + j); + scalar_vec_t res(residual_p + j); + + vec_op::FP32Vec8 fp32_w(w); + vec_op::FP32Vec8 fp32_res(res); + + vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(input_p + j); + } + } +} +} // namespace + +void rms_norm(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) + }); +} + +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "fused_add_rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl) + fused_add_rms_norm_impl( + input.data_ptr(), residual.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl) + }); +} diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp new file mode 100644 index 0000000000000..e9b3992204bb2 --- /dev/null +++ b/csrc/cpu/pos_encoding.cpp @@ -0,0 +1,199 @@ + +#include "cpu_types.hpp" + +namespace { +template +void rotary_embedding_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + constexpr int ELEM_SIZE = sizeof(scalar_t); + + const int embed_dim = rot_dim / 2; + TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(query + out_x); + const scalar_vec_t q_y(query + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(query + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(query + out_y); + } + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t k_x(key + out_x); + const scalar_vec_t k_y(key + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_k_x(k_x); + vec_op::FP32Vec8 fp32_k_y(k_y); + + auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; + scalar_vec_t(out1).save(key + out_x); + auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; + scalar_vec_t(out2).save(key + out_y); + } + } + } +} + +template +void rotary_embedding_gptj_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + const int embed_dim = rot_dim / 2; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + scalar_t *head_query = token_head + query; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_query[x_index]; + const float y = head_query[y_index]; + + head_query[x_index] = x * cos - y * sin; + head_query[y_index] = y * cos + x * sin; + } + } + } + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_kv_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + scalar_t *head_key = key + token_head; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_key[x_index]; + const float y = head_key[y_index]; + + head_key[x_index] = x * cos - y * sin; + head_key[y_index] = y * cos + x * sin; + } + } + } +} +}; // namespace + +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding_impl", [&] { + CPU_KERNEL_GUARD_IN(rotary_embedding_impl) + if (is_neox) { + rotary_embedding_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } else { + rotary_embedding_gptj_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } + + CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) + }); +} diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp new file mode 100644 index 0000000000000..bba044087f37c --- /dev/null +++ b/csrc/cpu/pybind.cpp @@ -0,0 +1,73 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_and_mul", + &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def( + "gelu_tanh_and_mul", + &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + "fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); +} diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst new file mode 100644 index 0000000000000..ba8b0645adcdf --- /dev/null +++ b/docs/source/getting_started/cpu-installation.rst @@ -0,0 +1,87 @@ +.. _installation_cpu: + +Installation with CPU +======================== + +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. + +Table of contents: + +#. :ref:`Requirements ` +#. :ref:`Quick start using Dockerfile ` +#. :ref:`Build from source ` +#. :ref:`Performance tips ` + +.. _cpu_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Compiler: gcc/g++>=12.3.0 (recommended) +* Instruction set architecture (ISA) requirement: AVX512 is required. + +.. _cpu_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g . + $ docker run -it \ + --rm \ + --network=host \ + --cpuset-cpus= \ + --cpuset-mems= \ + vllm-cpu-env + +.. _build_cpu_backend_from_source: + +Build from source +----------------- + +- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: + +.. code-block:: console + + $ sudo apt-get update -y + $ sudo apt-get install -y gcc-12 g++-12 + $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +- Second, install Python packages for vLLM CPU backend building: + +.. code-block:: console + + $ pip install --upgrade pip + $ pip install wheel packaging ninja setuptools>=49.4.0 numpy + $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +- Finally, build and install vLLM CPU backend: + +.. code-block:: console + + $ VLLM_TARGET_DEVICE=cpu python setup.py install + +.. note:: + - BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support. + + - AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. + + - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. + +.. _cpu_backend_performance_tips: + +Performance tips +----------------- + +- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. + +- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription. + +- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading. + +- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid the remote memory node access. ``numactl`` is an useful tool for CPU core and memory binding on NUMA platform. Besides, ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful. + + + diff --git a/docs/source/index.rst b/docs/source/index.rst index 5196ef062dc19..390409204cbc3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ Documentation getting_started/installation getting_started/amd-installation getting_started/neuron-installation + getting_started/cpu-installation getting_started/quickstart .. toctree:: diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000..580bffea5a018 --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,15 @@ +cmake>=3.21 +ninja # For faster builds. +psutil +ray >= 2.9 +sentencepiece # Required for LLaMA tokenizer. +numpy +transformers >= 4.38.0 # Required for Gemma. +fastapi +uvicorn[standard] +pydantic >= 2.0 # Required for OpenAI server. +prometheus_client >= 0.18.0 +torch == 2.1.2+cpu +triton >= 2.1.0 +filelock == 3.13.3 +py-cpuinfo \ No newline at end of file diff --git a/setup.py b/setup.py index 225fda0a0b412..e80226faa4807 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) +# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] +VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") # vLLM only supports Linux platform assert sys.platform.startswith( @@ -112,6 +114,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir), '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp), + '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] verbose = bool(int(os.getenv('VERBOSE', '0'))) @@ -185,11 +188,14 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return torch.version.cuda is not None and not _is_neuron() + return VLLM_TARGET_DEVICE == "cuda" \ + and torch.version.cuda is not None \ + and not _is_neuron() def _is_hip() -> bool: - return torch.version.hip is not None + return (VLLM_TARGET_DEVICE == "cuda" + or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None def _is_neuron() -> bool: @@ -201,6 +207,10 @@ def _is_neuron() -> bool: return torch_neuronx_installed +def _is_cpu() -> bool: + return VLLM_TARGET_DEVICE == "cpu" + + def _install_punica() -> bool: return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) @@ -296,6 +306,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_cpu(): + version += "+cpu" else: raise RuntimeError("Unknown runtime environment") @@ -322,6 +334,9 @@ def get_requirements() -> List[str]: elif _is_neuron(): with open(get_path("requirements-neuron.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_cpu(): + with open(get_path("requirements-cpu.txt")) as f: + requirements = f.read().strip().split("\n") else: raise ValueError( "Unsupported platform, please use CUDA, ROCM or Neuron.") diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000..4f69ebef662cb --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,253 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": + return TorchSDPAMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + prompt_lens: Optional[List[int]] + prompt_lens_tensor: Optional[torch.Tensor] + num_prompt_tokens: int + num_generation_tokens: int + + max_subquery_len: Optional[int] = None + max_prompt_len: Optional[int] = None + subquery_start_loc: Optional[torch.Tensor] = None + seq_start_loc: Optional[torch.Tensor] = None + use_cuda_graph: bool = False + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + +class TorchSDPABackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + assert len(alibi_slopes) == num_heads + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) + + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.prompt_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.prompt_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): + end = start + prompt_len + sub_out = scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=not self.need_mask, + scale=self.scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + else: + # prefix-enabled attention + raise RuntimeError( + "Torch SDPA backend doesn't support prefix decoding.") + + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + prompt_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, prompt_len, prompt_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + prompt_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + tensor = torch.full( + (1, prompt_len, prompt_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c2ec4376c9f3c..b5cd39bbe6252 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -5,7 +5,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) @@ -17,6 +17,10 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend + elif is_cpu(): + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -29,6 +33,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool: # AMD GPUs. logger.info("Cannot use FlashAttention backend for AMD GPUs.") return False + if is_cpu(): + return False if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " diff --git a/vllm/config.py b/vllm/config.py index 903829d8b176d..eef3fc53c3a65 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, + is_neuron) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -598,6 +599,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_cpu(): + self.device_type = "cpu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e9f72c17bf8dc..8d61f2f9ff193 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -332,7 +332,7 @@ def add_cli_args( parser.add_argument("--device", type=str, default=EngineArgs.device, - choices=["auto", "cuda", "neuron"], + choices=["auto", "cuda", "neuron", "cpu"], help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser.add_argument( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dec42c633b10b..7047b23bbe27f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -178,6 +178,9 @@ def from_engine_args( if device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor + elif device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor elif parallel_config.worker_use_ray: initialize_ray_cluster(parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py new file mode 100644 index 0000000000000..7b3cc784c98e5 --- /dev/null +++ b/vllm/executor/cpu_executor.py @@ -0,0 +1,154 @@ +import os +from typing import Dict, List, Optional + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import get_distributed_init_method, get_ip, get_open_port + +logger = init_logger(__name__) + + +class CPUExecutor(ExecutorBase): + + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: + assert device_config.device_type == "cpu" + assert lora_config is None, "cpu backend doesn't support LoRA" + model_config = _verify_and_get_model_config(model_config) + cache_config = _verify_and_get_cache_config(cache_config) + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + # Instantiate the worker and load the model to CPU. + self._init_worker() + self._init_cache() + + def _init_worker(self): + from vllm.worker.cpu_worker import CPUWorker + + assert self.parallel_config.world_size == 1, ( + "CPUExecutor only supports single CPU socket currently.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = CPUWorker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _init_cache(self) -> None: + num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num( + block_size=self.cache_config.block_size, + cache_space=self.cache_config.cpu_kvcache_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + ) + + logger.info(f"# CPU blocks: {num_cpu_blocks}") + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore + self.cache_config.num_cpu_blocks = 0 # type: ignore + + # Initialize the cache. + self.driver_worker.init_cache_engine(cache_config=self.cache_config) + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + output = self.driver_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def list_loras(self) -> List[int]: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def check_health(self) -> None: + # CPUExecutor will always be healthy as long as + # it's running. + return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.float16: + logger.warning("float16 is not supported on CPU, casting to bfloat16.") + config.dtype = torch.bfloat16 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + _GB = 1 << 30 + if config.enable_prefix_caching: + logger.warning("Prefix caching is not supported on CPU, disable it.") + config.enable_prefix_caching = False + + kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") + kv_cache_space = int(kv_cache_space_str) + + if kv_cache_space >= 0: + if kv_cache_space == 0: + config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore + logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/utils.py b/vllm/utils.py index 93fff4ffc9361..17b97f393ff21 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -117,6 +117,13 @@ def is_hip() -> bool: return torch.version.hip is not None +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import version + is_cpu_flag = "cpu" in version("vllm") + return is_cpu_flag + + @lru_cache(maxsize=None) def is_neuron() -> bool: try: @@ -362,6 +369,9 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif is_cpu(): + print_warning_once("Pin memory is not supported on CPU.") + return False return True diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 0000000000000..262ed9abd36b7 --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,280 @@ +"""A CPU worker class.""" +from typing import Dict, List, Optional + +import torch +import torch.distributed + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.worker.model_runner import ModelRunner + +logger = init_logger(__name__) + + +class CPUModelRunner(ModelRunner): + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend(model_config.dtype) + + # Initialize the cache. + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker: + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = CPUModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cpu_cache = None + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def get_cpu_cache_block_num( + self, + block_size: int, + cache_space: int, + cache_dtype: str, + ) -> int: + """ + Args: + block_size: The size of the cache block. + cache_space: The size of the CPU KV cache space in bytes. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = CPUCacheEngine.get_cache_block_size( + block_size, cache_dtype, self.model_config, self.parallel_config) + num_cpu_blocks = int(cache_space // cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + return num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CPUCacheEngine(self.cache_config, + self.model_config, + self.parallel_config, + self.device_config) + self.cpu_cache = self.cache_engine.cpu_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.cpu_cache is not None + + # Populate the cache to warmup the memory + for layer_cache in self.cpu_cache: + layer_cache.fill_(0) + + def cache_copy( + self, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + assert len(blocks_to_swap_in) == 0 + assert len(blocks_to_swap_out) == 0 + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.cpu_cache) + return output + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch " + "world size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + backend = "gloo" + torch.distributed.init_process_group( + backend=backend, + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size)