ARM-NEON intrinsics code paths now type-safe (microsoft#115)
walbourn authored Jul 31, 2020
1 parent 404c59a commit 9962628
Showing 6 changed files with 561 additions and 539 deletions.
28 changes: 22 additions & 6 deletions Inc/DirectXMath.h
@@ -196,7 +196,7 @@

#if defined(_XM_ARM_NEON_INTRINSICS_) && !defined(_XM_NO_INTRINSICS_)

-#if defined(__clang__)
+#if defined(__clang__) || defined(__GNUC__)
#define XM_PREFETCH( a ) __builtin_prefetch(a)
#elif defined(_MSC_VER)
#define XM_PREFETCH( a ) __prefetch(a)
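
For context, a minimal standalone sketch (not from this commit; SumWithPrefetch is a hypothetical helper, and the MSVC branch assumes an ARM target) of how the widened guard resolves: GCC now picks up __builtin_prefetch just like Clang, MSVC keeps __prefetch, and anything else compiles the macro away.

#include <cstddef>

#if defined(__clang__) || defined(__GNUC__)
#define XM_PREFETCH(a) __builtin_prefetch(a)
#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
#define XM_PREFETCH(a) __prefetch(a)
#else
#define XM_PREFETCH(a)
#endif

// Hypothetical helper: hint the cache a few iterations ahead while summing.
float SumWithPrefetch(const float* data, std::size_t count)
{
    float total = 0.f;
    for (std::size_t i = 0; i < count; ++i)
    {
        if (i + 16 < count)
            XM_PREFETCH(data + i + 16);
        total += data[i];
    }
    return total;
}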
@@ -380,9 +380,13 @@ namespace DirectX

inline operator XMVECTOR() const noexcept { return v; }
inline operator const float* () const noexcept { return f; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
#endif
};
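
The added operators matter because, on GCC and Clang, XMVECTOR is float32x4_t and the NEON integer vectors (int32x4_t, uint32x4_t) are distinct types that never convert implicitly, whereas MSVC's __n128 is a single type shared by all lane interpretations. A minimal sketch of the resulting cast pattern (standalone example with a made-up AbsBits helper, not library code):

#include <arm_neon.h>

// Reinterpret float lanes as raw bits, clear the sign bits, and hand the
// result back as floats -- every step needs an explicit vreinterpretq cast.
float32x4_t AbsBits(float32x4_t v)
{
    uint32x4_t bits = vreinterpretq_u32_f32(v);
    bits = vandq_u32(bits, vdupq_n_u32(0x7FFFFFFFu)); // clear sign bit per lane
    return vreinterpretq_f32_u32(bits);
}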

@@ -395,9 +399,13 @@
};

inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
#endif
};

@@ -410,9 +418,13 @@
};

inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
#endif
};

@@ -425,9 +437,13 @@
};

inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
#endif
};

@@ -2166,7 +2182,7 @@ namespace DirectX
// Convert DivExponent into 1.0f/(1<<DivExponent)
uint32_t uScale = 0x3F800000U - (DivExponent << 23);
// Splat the scalar value (It's really a float)
-vScale = vdupq_n_u32(uScale);
+vScale = vreinterpretq_s32_u32(vdupq_n_u32(uScale));
// Multiply by the reciprocal (Perform a right shift by DivExponent)
vResult = vmulq_f32(vResult, reinterpret_cast<const float32x4_t*>(&vScale)[0]);
return vResult;
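
A standalone check (not part of the commit; ScaleFromExponent is a made-up name) of the bit trick used above: subtracting DivExponent from the exponent field of 1.0f (0x3F800000) yields exactly 1.0f / (1 << DivExponent), as long as the result stays a normal float.

#include <cassert>
#include <cstdint>
#include <cstring>

float ScaleFromExponent(uint32_t divExponent)
{
    // 0x3F800000 is 1.0f; each step of 1 << 23 lowers the IEEE-754 exponent by one.
    uint32_t bits = 0x3F800000U - (divExponent << 23);
    float scale;
    std::memcpy(&scale, &bits, sizeof(scale));
    return scale;
}

int main()
{
    for (uint32_t n = 0; n < 31; ++n)
        assert(ScaleFromExponent(n) == 1.0f / static_cast<float>(1u << n));
    return 0;
}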
78 changes: 39 additions & 39 deletions Inc/DirectXMathConvert.inl
@@ -39,7 +39,7 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorIntToFloat
return Result;
#elif defined(_XM_ARM_NEON_INTRINSICS_)
float fScale = 1.0f / (float)(1U << DivExponent);
-float32x4_t vResult = vcvtq_f32_s32(VInt);
+float32x4_t vResult = vcvtq_f32_s32(vreinterpretq_s32_f32(VInt));
return vmulq_n_f32(vResult, fScale);
#else // _XM_SSE_INTRINSICS_
// Convert to floats
@@ -91,10 +91,10 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorFloatToInt
// Float to int conversion
int32x4_t vResulti = vcvtq_s32_f32(vResult);
// If there was positive overflow, set to 0x7FFFFFFF
-vResult = vandq_u32(vOverflow, g_XMAbsMask);
-vOverflow = vbicq_u32(vResulti, vOverflow);
-vOverflow = vorrq_u32(vOverflow, vResult);
-return vOverflow;
+vResult = vreinterpretq_f32_u32(vandq_u32(vOverflow, g_XMAbsMask));
+vOverflow = vbicq_u32(vreinterpretq_u32_s32(vResulti), vOverflow);
+vOverflow = vorrq_u32(vOverflow, vreinterpretq_u32_f32(vResult));
+return vreinterpretq_f32_u32(vOverflow);
#else // _XM_SSE_INTRINSICS_
XMVECTOR vResult = _mm_set_ps1(static_cast<float>(1U << MulExponent));
vResult = _mm_mul_ps(vResult, VFloat);
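
The NEON path above keeps the long-standing clamp-on-overflow behaviour while becoming type-safe. A minimal standalone sketch of that clamp pattern (FloatToInt32Clamped is a hypothetical name, not library code): lanes at or above 2^31 are detected before conversion and forced to INT32_MAX afterwards.

#include <arm_neon.h>

int32x4_t FloatToInt32Clamped(float32x4_t v)
{
    uint32x4_t overflow = vcgeq_f32(v, vdupq_n_f32(2147483648.0f));               // lanes too big for int32
    int32x4_t converted = vcvtq_s32_f32(v);                                        // float -> int32 (truncates)
    uint32x4_t clamped  = vbicq_u32(vreinterpretq_u32_s32(converted), overflow);   // zero the overflowed lanes
    clamped = vorrq_u32(clamped, vandq_u32(overflow, vdupq_n_u32(0x7FFFFFFFu)));   // force them to INT32_MAX
    return vreinterpretq_s32_u32(clamped);
}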
@@ -129,7 +129,7 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorUIntToFloat
return Result;
#elif defined(_XM_ARM_NEON_INTRINSICS_)
float fScale = 1.0f / (float)(1U << DivExponent);
-float32x4_t vResult = vcvtq_f32_u32(VUInt);
+float32x4_t vResult = vcvtq_f32_u32(vreinterpretq_u32_f32(VUInt));
return vmulq_n_f32(vResult, fScale);
#else // _XM_SSE_INTRINSICS_
// For the values that are higher than 0x7FFFFFFF, a fixup is needed
@@ -191,9 +191,9 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorFloatToUInt
// Float to int conversion
uint32x4_t vResulti = vcvtq_u32_f32(vResult);
// If there was overflow, set to 0xFFFFFFFFU
-vResult = vbicq_u32(vResulti, vOverflow);
-vOverflow = vorrq_u32(vOverflow, vResult);
-return vOverflow;
+vResult = vreinterpretq_f32_u32(vbicq_u32(vResulti, vOverflow));
+vOverflow = vorrq_u32(vOverflow, vreinterpretq_u32_f32(vResult));
+return vreinterpretq_f32_u32(vOverflow);
#else // _XM_SSE_INTRINSICS_
XMVECTOR vResult = _mm_set_ps1(static_cast<float>(1U << MulExponent));
vResult = _mm_mul_ps(vResult, VFloat);
@@ -240,7 +240,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt(const uint32_t* pSource) noexcept
return V;
#elif defined(_XM_ARM_NEON_INTRINSICS_)
uint32x4_t zero = vdupq_n_u32(0);
-return vld1q_lane_u32(pSource, zero, 0);
+return vreinterpretq_f32_u32(vld1q_lane_u32(pSource, zero, 0));
#elif defined(_XM_SSE_INTRINSICS_)
return _mm_load_ss(reinterpret_cast<const float*>(pSource));
#endif
@@ -281,7 +281,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt2(const uint32_t* pSource) noexcept
#elif defined(_XM_ARM_NEON_INTRINSICS_)
uint32x2_t x = vld1_u32(pSource);
uint32x2_t zero = vdup_n_u32(0);
-return vcombine_u32(x, zero);
+return vreinterpretq_f32_u32(vcombine_u32(x, zero));
#elif defined(_XM_SSE_INTRINSICS_)
return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
#endif
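
As above, loading fewer than four elements on NEON builds a 64-bit half first and widens it; under GCC/Clang the result then has to be reinterpreted to the float32x4_t that XMVECTOR aliases. A small standalone sketch (LoadTwoUints is a made-up helper):

#include <arm_neon.h>
#include <cstdint>

// Load two uint32 values into lanes 0-1, zero lanes 2-3,
// and return the bit pattern as a float-typed vector.
float32x4_t LoadTwoUints(const uint32_t* src)
{
    uint32x2_t lo = vld1_u32(src);   // lanes 0-1
    uint32x2_t hi = vdup_n_u32(0);   // lanes 2-3 = 0
    return vreinterpretq_f32_u32(vcombine_u32(lo, hi));
}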
@@ -307,7 +307,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt2A(const uint32_t* pSource) noexcept
uint32x2_t x = vld1_u32(pSource);
#endif
uint32x2_t zero = vdup_n_u32(0);
-return vcombine_u32(x, zero);
+return vreinterpretq_f32_u32(vcombine_u32(x, zero));
#elif defined(_XM_SSE_INTRINSICS_)
return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
#endif
@@ -434,7 +434,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt3(const uint32_t* pSource) noexcept
uint32x2_t x = vld1_u32(pSource);
uint32x2_t zero = vdup_n_u32(0);
uint32x2_t y = vld1_lane_u32(pSource + 2, zero, 0);
-return vcombine_u32(x, y);
+return vreinterpretq_f32_u32(vcombine_u32(x, y));
#elif defined(_XM_SSE4_INTRINSICS_)
__m128 xy = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
__m128 z = _mm_load_ss(reinterpret_cast<const float*>(pSource + 2));
@@ -466,7 +466,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt3A(const uint32_t* pSource) noexcept
#else
uint32x4_t V = vld1q_u32(pSource);
#endif
-return vsetq_lane_u32(0, V, 3);
+return vreinterpretq_f32_u32(vsetq_lane_u32(0, V, 3));
#elif defined(_XM_SSE4_INTRINSICS_)
__m128 xy = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
__m128 z = _mm_load_ss(reinterpret_cast<const float*>(pSource + 2));
@@ -614,7 +614,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt4(const uint32_t* pSource) noexcept
V.vector4_u32[3] = pSource[3];
return V;
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-return vld1q_u32(pSource);
+return vreinterpretq_f32_u32(vld1q_u32(pSource));
#elif defined(_XM_SSE_INTRINSICS_)
__m128i V = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSource));
return _mm_castsi128_ps(V);
@@ -638,7 +638,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt4A(const uint32_t* pSource) noexcept
#ifdef _MSC_VER
return vld1q_u32_ex(pSource, 128);
#else
-return vld1q_u32(pSource);
+return vreinterpretq_f32_u32(vld1q_u32(pSource));
#endif
#elif defined(_XM_SSE_INTRINSICS_)
__m128i V = _mm_load_si128(reinterpret_cast<const __m128i*>(pSource));
@@ -780,8 +780,8 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x3(const XMFLOAT3X3* pSource) noexcept
float32x4_t T = vextq_f32(v0, v1, 3);

XMMATRIX M;
-M.r[0] = vandq_u32(v0, g_XMMask3);
-M.r[1] = vandq_u32(T, g_XMMask3);
+M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T), g_XMMask3));
M.r[2] = vcombine_f32(vget_high_f32(v1), v2);
M.r[3] = g_XMIdentityR3;
return M;
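
The pattern above packs three-wide rows into four-lane registers with vextq_f32 and then clears the unused w lane. A standalone sketch of the masking step (ZeroW is a made-up helper, and the local mask stands in for the library's g_XMMask3 constant, which selects x, y, z and zeroes w):

#include <arm_neon.h>

// Force lane 3 (w) to zero while leaving x, y, z untouched.
float32x4_t ZeroW(float32x4_t v)
{
    uint32x4_t mask3 = vsetq_lane_u32(0, vdupq_n_u32(0xFFFFFFFFu), 3);
    uint32x4_t bits = vandq_u32(vreinterpretq_u32_f32(v), mask3);
    return vreinterpretq_f32_u32(bits);
}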
@@ -846,9 +846,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat4x3(const XMFLOAT4X3* pSource) noexcept
float32x4_t T3 = vextq_f32(v2, v2, 1);

XMMATRIX M;
-M.r[0] = vandq_u32(v0, g_XMMask3);
-M.r[1] = vandq_u32(T1, g_XMMask3);
-M.r[2] = vandq_u32(T2, g_XMMask3);
+M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
M.r[3] = vsetq_lane_f32(1.f, T3, 3);
return M;
#elif defined(_XM_SSE_INTRINSICS_)
@@ -930,9 +930,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat4x3A(const XMFLOAT4X3A* pSource) noexcept
float32x4_t T3 = vextq_f32(v2, v2, 1);

XMMATRIX M;
-M.r[0] = vandq_u32(v0, g_XMMask3);
-M.r[1] = vandq_u32(T1, g_XMMask3);
-M.r[2] = vandq_u32(T2, g_XMMask3);
+M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
M.r[3] = vsetq_lane_f32(1.f, T3, 3);
return M;
#elif defined(_XM_SSE_INTRINSICS_)
@@ -1012,9 +1012,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x4(const XMFLOAT3X4* pSource) noexcept
float32x4_t T3 = vcombine_f32(vTemp0.val[3], rh);

XMMATRIX M = {};
-M.r[0] = vandq_u32(T0, g_XMMask3);
-M.r[1] = vandq_u32(T1, g_XMMask3);
-M.r[2] = vandq_u32(T2, g_XMMask3);
+M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T0), g_XMMask3));
+M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
M.r[3] = vsetq_lane_f32(1.f, T3, 3);
return M;
#elif defined(_XM_SSE_INTRINSICS_)
@@ -1096,9 +1096,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x4A(const XMFLOAT3X4A* pSource) noexcept
float32x4_t T3 = vcombine_f32(vTemp0.val[3], rh);

XMMATRIX M = {};
-M.r[0] = vandq_u32(T0, g_XMMask3);
-M.r[1] = vandq_u32(T1, g_XMMask3);
-M.r[2] = vandq_u32(T2, g_XMMask3);
+M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T0), g_XMMask3));
+M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
M.r[3] = vsetq_lane_f32(1.f, T3, 3);
return M;
#elif defined(_XM_SSE_INTRINSICS_)
@@ -1283,7 +1283,7 @@ inline void XM_CALLCONV XMStoreInt2
pDestination[0] = V.vector4_u32[0];
pDestination[1] = V.vector4_u32[1];
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-uint32x2_t VL = vget_low_u32(V);
+uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
vst1_u32(pDestination, VL);
#elif defined(_XM_SSE_INTRINSICS_)
_mm_store_sd(reinterpret_cast<double*>(pDestination), _mm_castps_pd(V));
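Storing the low half follows the same discipline in reverse: take the low 64 bits, but on GCC/Clang the vector has to be viewed as uint32x4_t before vget_low_u32 will accept it. A minimal sketch (StoreTwoUints is a hypothetical name, not library code):

#include <arm_neon.h>
#include <cstdint>

// Write lanes 0-1 of a float-typed vector out as two uint32 values.
void StoreTwoUints(uint32_t* dst, float32x4_t v)
{
    uint32x2_t lo = vget_low_u32(vreinterpretq_u32_f32(v));
    vst1_u32(dst, lo);
}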
@@ -1304,7 +1304,7 @@ inline void XM_CALLCONV XMStoreInt2A
pDestination[0] = V.vector4_u32[0];
pDestination[1] = V.vector4_u32[1];
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-uint32x2_t VL = vget_low_u32(V);
+uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
#ifdef _MSC_VER
vst1_u32_ex(pDestination, VL, 64);
#else
@@ -1373,9 +1373,9 @@ inline void XM_CALLCONV XMStoreSInt2
pDestination->x = static_cast<int32_t>(V.vector4_f32[0]);
pDestination->y = static_cast<int32_t>(V.vector4_f32[1]);
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-int32x2_t v = vget_low_s32(V);
-v = vcvt_s32_f32(v);
-vst1_s32(reinterpret_cast<int32_t*>(pDestination), v);
+float32x2_t v = vget_low_f32(V);
+int32x2_t iv = vcvt_s32_f32(v);
+vst1_s32(reinterpret_cast<int32_t*>(pDestination), iv);
#elif defined(_XM_SSE_INTRINSICS_)
// In case of positive overflow, detect it
XMVECTOR vOverflow = _mm_cmpgt_ps(V, g_XMMaxInt);
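
This hunk is more than a cast cleanup: the old code declared the low half as int32x2_t and fed it to vcvt_s32_f32, which only type-checks on MSVC where every NEON vector shares the __n128/__n64 representation; the new code keeps the low half as float32x2_t and converts afterwards. A standalone sketch of the corrected order (StoreTwoInts is a made-up helper):

#include <arm_neon.h>
#include <cstdint>

// Convert lanes 0-1 from float to int32 (truncating) and store them.
void StoreTwoInts(int32_t* dst, float32x4_t v)
{
    float32x2_t lo = vget_low_f32(v);   // take the low half while still floats
    int32x2_t   iv = vcvt_s32_f32(lo);  // then convert float -> int32
    vst1_s32(dst, iv);
}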
@@ -1443,7 +1443,7 @@ inline void XM_CALLCONV XMStoreInt3
pDestination[1] = V.vector4_u32[1];
pDestination[2] = V.vector4_u32[2];
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-uint32x2_t VL = vget_low_u32(V);
+uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
vst1_u32(pDestination, VL);
vst1q_lane_u32(pDestination + 2, *reinterpret_cast<const uint32x4_t*>(&V), 2);
#elif defined(_XM_SSE_INTRINSICS_)
@@ -1468,7 +1468,7 @@ inline void XM_CALLCONV XMStoreInt3A
pDestination[1] = V.vector4_u32[1];
pDestination[2] = V.vector4_u32[2];
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-uint32x2_t VL = vget_low_u32(V);
+uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
#ifdef _MSC_VER
vst1_u32_ex(pDestination, VL, 64);
#else
@@ -1634,7 +1634,7 @@ inline void XM_CALLCONV XMStoreInt4
pDestination[2] = V.vector4_u32[2];
pDestination[3] = V.vector4_u32[3];
#elif defined(_XM_ARM_NEON_INTRINSICS_)
-vst1q_u32(pDestination, V);
+vst1q_u32(pDestination, vreinterpretq_u32_f32(V));
#elif defined(_XM_SSE_INTRINSICS_)
_mm_storeu_si128(reinterpret_cast<__m128i*>(pDestination), _mm_castps_si128(V));
#endif
@@ -1659,7 +1659,7 @@ inline void XM_CALLCONV XMStoreInt4A
#ifdef _MSC_VER
vst1q_u32_ex(pDestination, V, 128);
#else
-vst1q_u32(pDestination, V);
+vst1q_u32(pDestination, vreinterpretq_u32_f32(V));
#endif
#elif defined(_XM_SSE_INTRINSICS_)
_mm_store_si128(reinterpret_cast<__m128i*>(pDestination), _mm_castps_si128(V));