diff --git a/genfft/c.ml b/genfft/c.ml index 2390aa2c4..d6e6c0238 100644 --- a/genfft/c.ml +++ b/genfft/c.ml @@ -46,6 +46,7 @@ type c_decl = and c_ast = | Asch of annotated_schedule + | Simd_leavefun | Return of c_ast | For of c_ast * c_ast * c_ast * c_ast | If of c_ast * c_ast @@ -204,6 +205,7 @@ and unparse_ast = in function | Asch a -> (unparse_annotated true a) + | Simd_leavefun -> "" (* used only in SIMD code *) | Return x -> "return " ^ unparse_ast x ^ ";" | For (a, b, c, d) -> "for (" ^ diff --git a/genfft/c.mli b/genfft/c.mli index 9a72bc731..d239451a0 100644 --- a/genfft/c.mli +++ b/genfft/c.mli @@ -42,6 +42,7 @@ type c_decl = and c_ast = | Asch of Annotate.annotated_schedule + | Simd_leavefun | Return of c_ast | For of c_ast * c_ast * c_ast * c_ast | If of c_ast * c_ast diff --git a/genfft/gen_hc2c.ml b/genfft/gen_hc2c.ml index bed6de1fb..3501eebd7 100644 --- a/genfft/gen_hc2c.ml +++ b/genfft/gen_hc2c.ml @@ -154,7 +154,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_hc2cdft.ml b/genfft/gen_hc2cdft.ml index 48439e937..057344a8d 100644 --- a/genfft/gen_hc2cdft.ml +++ b/genfft/gen_hc2cdft.ml @@ -176,7 +176,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_hc2cdft_c.ml b/genfft/gen_hc2cdft_c.ml index 9a22d9e51..3050b8e3c 100644 --- a/genfft/gen_hc2cdft_c.ml +++ b/genfft/gen_hc2cdft_c.ml @@ -188,7 +188,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_hc2hc.ml b/genfft/gen_hc2hc.ml index e5f53971a..366e5b980 100644 --- a/genfft/gen_hc2hc.ml +++ 
b/genfft/gen_hc2hc.ml @@ -138,7 +138,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_mdct.ml b/genfft/gen_mdct.ml index 198014702..11eb53980 100644 --- a/genfft/gen_mdct.ml +++ b/genfft/gen_mdct.ml @@ -242,7 +242,7 @@ let generate n mode = @ (if (not (window_param mode)) then [] else [Decl (C.constrealtypep, window)]) ), - add_constants (Asch annot)) + finalize_fcn (Asch annot)) in (unparse tree) ^ "\n" diff --git a/genfft/gen_notw.ml b/genfft/gen_notw.ml index 88f89c10d..c248817b5 100644 --- a/genfft/gen_notw.ml +++ b/genfft/gen_notw.ml @@ -140,7 +140,7 @@ let generate n = Decl ("INT", v); Decl ("INT", "ivs"); Decl ("INT", "ovs")]), - add_constants body) + finalize_fcn body) in let desc = Printf.sprintf diff --git a/genfft/gen_notw_c.ml b/genfft/gen_notw_c.ml index b17e6b73b..ecb6d0578 100644 --- a/genfft/gen_notw_c.ml +++ b/genfft/gen_notw_c.ml @@ -135,7 +135,7 @@ let generate n = Decl ("INT", v); Decl ("INT", "ivs"); Decl ("INT", "ovs")]), - add_constants body) + finalize_fcn body) in let desc = diff --git a/genfft/gen_r2cb.ml b/genfft/gen_r2cb.ml index a3155935e..558e06901 100644 --- a/genfft/gen_r2cb.ml +++ b/genfft/gen_r2cb.ml @@ -141,7 +141,7 @@ let generate n = Decl ("INT", v); Decl ("INT", "ivs"); Decl ("INT", "ovs")]), - add_constants body) + finalize_fcn body) in let desc = Printf.sprintf diff --git a/genfft/gen_r2cf.ml b/genfft/gen_r2cf.ml index d435138e8..60e651d7c 100644 --- a/genfft/gen_r2cf.ml +++ b/genfft/gen_r2cf.ml @@ -138,7 +138,7 @@ let generate n = Decl ("INT", v); Decl ("INT", "ivs"); Decl ("INT", "ovs")]), - add_constants body) + finalize_fcn body) in let desc = Printf.sprintf diff --git a/genfft/gen_r2r.ml b/genfft/gen_r2r.ml index e2f8a0f50..c445a7057 100644 --- a/genfft/gen_r2r.ml +++ b/genfft/gen_r2r.ml @@ -218,7 +218,7 @@ let generate n mode = else [Decl ("INT", 
"ivs")]) @ (if stride_fixed !uovstride then [] else [Decl ("INT", "ovs")]))), - add_constants body) + finalize_fcn body) in let desc = Printf.sprintf diff --git a/genfft/gen_twiddle.ml b/genfft/gen_twiddle.ml index c10b34297..bebf6bce7 100644 --- a/genfft/gen_twiddle.ml +++ b/genfft/gen_twiddle.ml @@ -123,7 +123,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_twiddle_c.ml b/genfft/gen_twiddle_c.ml index 4aeace005..e5ffbd1e7 100644 --- a/genfft/gen_twiddle_c.ml +++ b/genfft/gen_twiddle_c.ml @@ -127,7 +127,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_twidsq.ml b/genfft/gen_twidsq.ml index 8999131df..ad786da3b 100644 --- a/genfft/gen_twidsq.ml +++ b/genfft/gen_twidsq.ml @@ -138,7 +138,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/gen_twidsq_c.ml b/genfft/gen_twidsq_c.ml index 846f01455..f1209a0c7 100644 --- a/genfft/gen_twidsq_c.ml +++ b/genfft/gen_twidsq_c.ml @@ -148,7 +148,7 @@ let generate n = Decl ("INT", mb); Decl ("INT", me); Decl ("INT", ms)], - add_constants body) + finalize_fcn body) in let twinstr = Printf.sprintf "static const tw_instr twinstr[] = %s;\n\n" diff --git a/genfft/genutil.ml b/genfft/genutil.ml index 301e36a18..0448a7d3d 100644 --- a/genfft/genutil.ml +++ b/genfft/genutil.ml @@ -306,7 +306,7 @@ let unparse tree = else C.unparse_function tree) -let add_constants ast = +let finalize_fcn ast = let mergedecls = function C.Block (d1, [C.Block (d2, s)]) -> C.Block (d1 @ d2, s) | x -> x @@ -316,7 +316,7 @@ let add_constants ast = else 
C.extract_constants - in mergedecls (C.Block (extract_constants ast, [ast])) + in mergedecls (C.Block (extract_constants ast, [ast; C.Simd_leavefun])) let twinstr_to_string vl x = if !Simdmagic.simd_mode then diff --git a/genfft/simd.ml b/genfft/simd.ml index a2fa2df8f..b897dd89d 100644 --- a/genfft/simd.ml +++ b/genfft/simd.ml @@ -186,6 +186,7 @@ and unparse_ast ast = in match ast with | Asch a -> (unparse_annotated true a) | Return x -> "return " ^ unparse_ast x ^ ";" + | Simd_leavefun -> "VLEAVE();" | For (a, b, c, d) -> "for (" ^ unparse_ast a ^ "; " ^ unparse_ast b ^ "; " ^ unparse_ast c diff --git a/simd-support/simd-avx256d.h b/simd-support/simd-avx256d.h index 133c0b988..86c43e02c 100644 --- a/simd-support/simd-avx256d.h +++ b/simd-support/simd-avx256d.h @@ -223,4 +223,9 @@ static inline V BYTWJ2(const R *t, V sr) #define VFMSCONJ(b,c) VSUB(VCONJ(b),c) #define VFNMSCONJ(b,c) VSUB(c, VCONJ(b)) +/* Use VZEROUPPER to avoid the penalty of switching from AVX to SSE. See Intel Optimization Manual (April 2011, version 248966), Section 11.3 */ +#define VLEAVE _mm256_zeroupper + #include "simd-common.h" diff --git a/simd-support/simd-sse2.h b/simd-support/simd-sse2.h index dedff130b..4251f1b69 100644 --- a/simd-support/simd-sse2.h +++ b/simd-support/simd-sse2.h @@ -208,4 +208,6 @@ static inline V BYTWJ2(const R *t, V sr) #define VFMSCONJ(b,c) VSUB(VCONJ(b),c) #define VFNMSCONJ(b,c) VSUB(c, VCONJ(b)) +#define VLEAVE() /* nothing */ + #include "simd-common.h"