Commit 82b6aa0

Integrate TPP code into IPEX (#1357)
* Add libxsmm as a third-party module
* Add TPP csrc files (initially hit a build/link issue)
* Enable the build
* Add a Python binding for libxsmm initialization
* Add Python bindings for the tpp/optimizer and tpp/pad/bert modules
* SQuAD inference gets the expected result
* SQuAD training passes
* Move all TPP-related code into the torch_ipex::tpp namespace
* Register the BERT fusion ops into the torch.ops.ipex namespace instead of using Python bindings
* Enable the unpad fp32 path
* Enable the unpad bf16 path
* Fix the unpad import issue
* Enable the ipex.tpp_optimize API
* Fix the transformers check issue
* Update frontend.py
* Unify the unpad/padded code based on the original unpad code
* Remove the unused padded code
* Remove extend_profiler.py
* Add UTs for TPP
* Add UTs for backward
* Update Module.cpp
* 1) Rename tpp_optimize to fast_bert; 2) add a tpp prefix to the optimizer-related Python APIs; 3) fix the optimizer replacement issue for optimizers other than SGD/AdamW
* Create README.md
* Fix clang-format issues
* Fix UT failures when transformers is not installed
* Update test_tpp_ops.py
* Fix UT failures
1 parent 82fb517 commit 82b6aa0
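
For orientation, the user-facing entry point added by this commit is `ipex.fast_bert` (renamed from the earlier `tpp_optimize`), which swaps eligible BERT modules for the TPP/libxsmm fused kernels registered below. A minimal usage sketch, assuming a Hugging Face `transformers` BERT model and a `dtype` keyword as suggested by the commit notes; the exact signature at this commit may differ:

```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").eval()

# Replace supported BERT blocks with the TPP fused kernels; bfloat16 would
# exercise the "unpad bf16" path enabled here (the dtype keyword is an assumption).
model = ipex.fast_bert(model, dtype=torch.bfloat16)

with torch.no_grad():
    input_ids = torch.randint(0, 30522, (1, 128))  # illustrative token ids
    outputs = model(input_ids)
```

For training, the commit's notes about optimizer handling (SGD/AdamW) suggest the optimizer is passed to the same call, but that is inferred from the commit message rather than a documented signature.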

41 files changed (+13359, −5 lines)

.gitmodules

Lines changed: 3 additions & 0 deletions
```diff
@@ -4,3 +4,6 @@
 [submodule "third_party/ideep"]
 	path = third_party/ideep
 	url = https://github.com/intel/ideep.git
+[submodule "third_party/libxsmm"]
+	path = third_party/libxsmm
+	url = https://github.com/libxsmm/libxsmm.git
```

csrc/cpu/CMakeLists.txt

Lines changed: 17 additions & 3 deletions
```diff
@@ -46,6 +46,7 @@ set(IPEX_CPU_CPP_ISA_SRCS)
 set(IPEX_CPU_CPP_TOOLKIT_SRCS)
 set(IPEX_CPU_CPP_IDEEP_SRCS)
 set(IPEX_CPU_CPP_RUNTIME_SRCS)
+set (IPEX_CPU_CPP_TPP_SRCS)

 set(IPEX_JIT_CPP_SRCS)
 set(IPEX_UTLIS_CPP_SRCS)
@@ -62,13 +63,13 @@ add_subdirectory(${IPEX_CPU_ROOT_DIR}/isa)
 add_subdirectory(${IPEX_CPU_ROOT_DIR}/toolkit)
 add_subdirectory(${IPEX_CPU_ROOT_DIR}/runtime)
 add_subdirectory(${IPEX_CPU_ROOT_DIR}/utils)
+add_subdirectory(${IPEX_CPU_ROOT_DIR}/tpp)

 add_subdirectory(${IPEX_JIT_CPP_ROOT} jit_cpu)
 add_subdirectory(${IPEX_UTLIS_CPP_ROOT} csrc_utlis)

 set(IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_DYNDISP_SRCS} ${IPEX_CPU_CPP_ISA_SRCS_GEN} ${IPEX_CPU_CPP_UTILS_SRCS} ${IPEX_CPU_CPP_QUANTIZATION_SRCS} ${IPEX_JIT_CPP_SRCS}
-    ${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_UTLIS_CPP_SRCS}
-    ${IPEX_CPU_CPP_TOOLKIT_SRCS})
+    ${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} ${IPEX_CPU_CPP_TPP_SRCS})

 list(REMOVE_ITEM IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_ISA_SRCS_ORIGIN})
@@ -81,10 +82,12 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_ROOT_DIR})
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR})
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/aten)
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/utils)
+target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/tpp)

 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_JIT_CPP_ROOT})
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_UTLIS_CPP_ROOT})

+target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/include)
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/third_party/oneDNN/include)
 # TODO: once llga is merged into oneDNN, use oneDNN directly as the third_party instead of using that inside llga
@@ -101,6 +104,18 @@ if(CLANG_FORMAT)
   add_dependencies(${PLUGIN_NAME_CPU} CL_FORMAT_CPU_NATIVE_CSRC)
 endif()

+include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
+ExternalProject_Add(libxsmm
+  SOURCE_DIR ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm
+  BUILD_IN_SOURCE 1
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND
+    make
+    "AVX=3"
+    "-j"
+  INSTALL_COMMAND ""
+)
+target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/lib/libxsmm.a)
 add_dependencies(${PLUGIN_NAME_CPU} dnnl_graph)
 # If Graph Compiler is built, then it should link to its LLVM dependencies,
 # and not the LLVM symbols exposed by PyTorch.
@@ -114,7 +129,6 @@ if (DEFINED ENV{DNNL_GRAPH_BUILD_COMPILER_BACKEND})
     set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs=${DNNL_GRAPHCOMPILER_LLVM_LIB_EXCLUDE}")
   endif()
 endif()
-
 find_package(oneMKL QUIET)
 if (ONEMKL_FOUND)
   target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEMKL_INCLUDE_DIR})
```

csrc/cpu/tpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
```cmake
FILE(GLOB _TPP_SRCS *.cpp bert/*.cpp)
LIST(APPEND IPEX_CPU_CPP_TPP_SRCS ${_TPP_SRCS})
# LIST(APPEND IPEX_CPU_CPP_ATEN_SRCS ${_CPU_KERNELS_SRCS})
message(STATUS "IPEX_CPU_CPP_TPP_SRCS: ${IPEX_CPU_CPP_TPP_SRCS}")
# Pass to parent
set(IPEX_CPU_CPP_TPP_SRCS ${IPEX_CPU_CPP_TPP_SRCS} PARENT_SCOPE)
```

csrc/cpu/tpp/README.md

Lines changed: 31 additions & 0 deletions
# Overview
This directory mainly includes the TPP-related tools and the fused kernel implementations based on TPP optimization.

```
├── bert                                              # fused kernels based on TPP
│   ├── fused_bert.cpp
│   ├── fused_dense_dropout_layernorm_bwd_tmpl.h      # backward for fused linear+dropout+layernorm
│   ├── fused_dense_dropout_layernorm_fwd_tmpl.h      # forward for fused linear+dropout+layernorm
│   ├── fused_dense_gelu_bwd_tmpl.h                   # backward for fused linear+gelu
│   ├── fused_dense_gelu_fwd_tmpl.h                   # forward for fused linear+gelu
│   ├── fused_embedding_layernorm_dropout_bwd_tmpl.h  # backward for fused embedding+add+layernorm+dropout
│   ├── fused_embedding_layernorm_dropout_fwd_tmpl.h  # forward for fused embedding+add+layernorm+dropout
│   ├── fused_self_attention_bwd_tmpl.h               # fused backward self-attention
│   └── fused_self_attention_fwd_tmpl.h               # fused forward self-attention
├── CMakeLists.txt
├── common_loops.cpp        # loop generation and tuning
├── ext_tpp.h
├── init.cpp
├── jit_compile.cpp
├── jit_compile.h
├── optim.cpp
├── optim.h
├── par_loop_generator.cpp  # loop generation and tuning
├── par_loop_generator.h    # loop generation and tuning
├── rtm.h
├── tensor_helper.h
├── threaded_loops.h
├── timing.h
├── utils.h
└── xsmm_functors.h         # the TPP definitions based on libxsmm
```

csrc/cpu/tpp/bert/fused_bert.cpp

Lines changed: 276 additions & 0 deletions
```cpp
#include <ATen/record_function.h>
//#include <torch/csrc/autograd/VariableTypeUtils.h>
//#include <torch/extension.h>

#include <dyndisp/DispatchStub.h>
#include <torch/all.h>
#include <iostream>
#include <vector>
#include "ext_tpp.h"
//#include "init.h"
#include "tensor_helper.h"
#include "threaded_loops.h"
#include "timing.h"
#include "xsmm_functors.h"

namespace torch_ipex {
namespace tpp {

static int my_rank = guess_mpi_rank();

REGISTER_LOCAL_SCOPE(b_emb, "b_emb");
REGISTER_LOCAL_SCOPE(q_gemm, "q_gemm");
REGISTER_LOCAL_SCOPE(k_gemm, "k_gemm");
REGISTER_LOCAL_SCOPE(v_gemm, "v_gemm");
REGISTER_LOCAL_SCOPE(ac_gemm, "ac_gemm");
REGISTER_LOCAL_SCOPE(o_gemm, "o_gemm");
REGISTER_LOCAL_SCOPE(i_gemm, "i_gemm");

REGISTER_LOCAL_SCOPE(db_emb, "db_emb");
REGISTER_LOCAL_SCOPE(diq_gemm, "diq_gemm");
REGISTER_LOCAL_SCOPE(dik_gemm, "dik_gemm");
REGISTER_LOCAL_SCOPE(div_gemm, "div_gemm");
REGISTER_LOCAL_SCOPE(dica_gemm, "dica_gemm");
REGISTER_LOCAL_SCOPE(dii_gemm, "dii_gemm");
REGISTER_LOCAL_SCOPE(dio_gemm, "dio_gemm");
REGISTER_LOCAL_SCOPE(dwqkv_gemm, "dwqkv_gemm");
REGISTER_LOCAL_SCOPE(dwq_gemm, "dwq_gemm");
REGISTER_LOCAL_SCOPE(dwk_gemm, "dwk_gemm");
REGISTER_LOCAL_SCOPE(dwv_gemm, "dwv_gemm");
REGISTER_LOCAL_SCOPE(dwa_gemm, "dwa_gemm");
REGISTER_LOCAL_SCOPE(dwc_gemm, "dwc_gemm");
REGISTER_LOCAL_SCOPE(dac_gemm, "dac_gemm");
REGISTER_LOCAL_SCOPE(dwi_gemm, "dwi_gemm");
REGISTER_LOCAL_SCOPE(dwo_gemm, "dwo_gemm");
REGISTER_LOCAL_SCOPE(dqkv_bias, "dqkv_bias");
REGISTER_LOCAL_SCOPE(di_bias, "di_bias");
REGISTER_LOCAL_SCOPE(do_bias, "do_bias");

template <typename T>
inline void omp_reduce_buf(
    int num_threads,
    int N,
    float** ptrs,
    T* buf,
    bool accumulate = false) {
  ScopedTimer _t(EW_RED);
#pragma omp for
  for (int i = 0; i < N; i++) {
    float sum = 0.0;
    for (int j = 0; j < num_threads; j++) {
      sum += ptrs[j][i];
    }
    if (accumulate) {
      buf[i] += sum;
    } else {
      buf[i] = sum;
    }
  }
}

static std::vector<at::Tensor> fused_self_attention_fwd_unpad(
    double p,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[6].dtype() == at::kFloat) {
    typedef float T;
#include "fused_self_attention_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_self_attention_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_self_attention_bwd_unpad(
    double p,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_self_attention_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_self_attention_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_fwd_unpad(
    double p,
    double eps,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_bwd_unpad(
    double p,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_gelu_fwd_unpad(
    at::Tensor t_in,
    at::Tensor t_wt,
    at::Tensor t_bias,
    bool training) {
  GlobalPass _gp(FWD);
  if (t_in.dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_gelu_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_gelu_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_gelu_bwd_unpad(
    at::Tensor t_grad_out,
    at::Tensor t_gelu_in,
    at::Tensor t_in,
    at::Tensor t_wt) {
  GlobalPass _gp(BWD);
  if (t_grad_out.dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_gelu_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_gelu_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_fwd_unpad(
    double p,
    double eps,
    long H,
    long pad_id,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
    typedef float T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
    typedef bfloat16 T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
    typedef float T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kBFloat16 &&
      inputs[6].dtype() == at::kBFloat16) {
    typedef bfloat16 T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else {
    PCL_ASSERT(0, "Should not come here\n");
  }
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_bwd_unpad(
    double p,
    long pad_id,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
    typedef float T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
    typedef bfloat16 T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
    typedef float T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kBFloat16 &&
      inputs[6].dtype() == at::kBFloat16) {
    typedef bfloat16 T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else {
    PCL_ASSERT(0, "Should not come here\n");
  }
}
} // namespace tpp
} // namespace torch_ipex

namespace {
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      torch::schema(
          "torch_ipex::fused_self_attention_fwd_unpad(float p, Tensor[] inputs, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_self_attention_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_self_attention_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_self_attention_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_dropout_layernorm_fwd_unpad(float p, float eps, Tensor[] inputs, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_dropout_layernorm_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_dropout_layernorm_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_dropout_layernorm_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_gelu_fwd_unpad(Tensor t_in, Tensor t_wt, Tensor "
          "t_bias, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_gelu_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_gelu_bwd_unpad(Tensor t_grad_out, Tensor t_gelu_in, "
          "Tensor t_in, Tensor t_wt) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_gelu_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_embedding_layernorm_dropout_fwd_unpad(float p, float "
          "eps, int H, int pad_id, Tensor(a!)[] inputs, bool training) -> "
          "Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_embedding_layernorm_dropout_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_embedding_layernorm_dropout_bwd_unpad(float p, int "
          "pad_id, Tensor(a!)[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_embedding_layernorm_dropout_bwd_unpad);
}
} // namespace
```
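
Because the fragment above registers the kernels with plain TorchScript schemas, they surface in Python under `torch.ops.torch_ipex` once the extension is imported. A minimal sketch for inspecting one of them; note that in the real flow the `ipex.fast_bert` frontend prepares the blocked weight layout the TPP kernels expect, so calling the op with arbitrary row-major tensors is deliberately not shown:

```python
import torch
import intel_extension_for_pytorch  # noqa: F401  (loads the plugin that runs TORCH_LIBRARY_FRAGMENT)

# Resolve the op and print its registered schema; it should match the
# torch::schema string in fused_bert.cpp above.
op = torch.ops.torch_ipex.fused_dense_gelu_fwd_unpad
print(op.default._schema)
# Expected form:
#   torch_ipex::fused_dense_gelu_fwd_unpad(Tensor t_in, Tensor t_wt, Tensor t_bias, bool training) -> Tensor[]
#
# The calling convention is then op(t_in, t_wt, t_bias, training), returning a list of Tensors.
```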
