Skip to content

Commit

Permalink
Added avx512 packed compare float double and better comments/renaming…
Browse files Browse the repository at this point in the history
… on comparison predicates
  • Loading branch information
gfurtadoalmeida committed Apr 15, 2020
1 parent d11192b commit 49402bd
Show file tree
Hide file tree
Showing 15 changed files with 234 additions and 58 deletions.
105 changes: 89 additions & 16 deletions book-01/Assembly.Test/avx512_packed.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "pch.h"
#include "CppUnitTest.h"
#include <random>
#define _USE_MATH_DEFINES
#include <math.h>
#include "../Assembly/asm-headers/avx.h"
Expand All @@ -14,6 +15,78 @@ namespace Assembly {
namespace AVX512 {
TEST_CLASS(Packed)
{
// Checks Compare_Double against the C++ relational operators, lane by lane.
// Compare_Double fills result[0..7] with one 8-bit mask per predicate, in the
// order ==, !=, <, <=, >, >=, ordered, unordered; bit i of a mask belongs to
// lane i (a.Double[i] vs b.Double[i]).
TEST_METHOD(Test_Compare_Double)
{
    const size_t LANE_COUNT = 8;

    alignas(64) ZmmVal a;
    alignas(64) ZmmVal b;
    alignas(64) uint8_t result[8];

    std::uniform_real_distribution<double> unif(1.0, 10000.0);
    std::default_random_engine re;

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        a.Double[i] = unif(re);
    }

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        b.Double[i] = unif(re);
    }

    // Two random draws practically never collide, so without this every lane
    // would only exercise the "not equal" side of the ==, != predicates.
    b.Double[0] = a.Double[0];

    Compare_Double(&a, &b, result);

    // Extracts the bit of the given lane from the mask of the given predicate.
    const auto laneBit = [&result](size_t predicate, size_t lane) -> bool
    {
        return ((result[predicate] >> lane) & 0x01) != 0;
    };

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        const bool unordered = isnan(a.Double[i]) || isnan(b.Double[i]);

        Assert::AreEqual(a.Double[i] == b.Double[i], laneBit(0, i));
        Assert::AreEqual(a.Double[i] != b.Double[i], laneBit(1, i));
        Assert::AreEqual(a.Double[i] < b.Double[i], laneBit(2, i));
        Assert::AreEqual(a.Double[i] <= b.Double[i], laneBit(3, i));
        Assert::AreEqual(a.Double[i] > b.Double[i], laneBit(4, i));
        Assert::AreEqual(a.Double[i] >= b.Double[i], laneBit(5, i));
        Assert::AreEqual(!unordered, laneBit(6, i));
        Assert::AreEqual(unordered, laneBit(7, i));
    }
}

// Checks Compare_Float against the C++ relational operators, lane by lane.
// Compare_Float fills result[0..7] with one 16-bit mask per predicate, in the
// order ==, !=, <, <=, >, >=, ordered, unordered; bit i of a mask belongs to
// lane i (a.Float[i] vs b.Float[i]).
TEST_METHOD(Test_Compare_Float)
{
    const size_t LANE_COUNT = 16;

    alignas(64) ZmmVal a;
    alignas(64) ZmmVal b;
    alignas(64) uint16_t result[8];

    std::uniform_real_distribution<float> unif(1.0F, 10000.0F);
    std::default_random_engine re;

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        a.Float[i] = unif(re);
    }

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        b.Float[i] = unif(re);
    }

    // Two random draws practically never collide, so without this every lane
    // would only exercise the "not equal" side of the ==, != predicates.
    b.Float[0] = a.Float[0];

    Compare_Float(&a, &b, result);

    // Extracts the bit of the given lane from the mask of the given predicate.
    const auto laneBit = [&result](size_t predicate, size_t lane) -> bool
    {
        return ((result[predicate] >> lane) & 0x0001) != 0;
    };

    for (size_t i = 0; i < LANE_COUNT; i++)
    {
        const bool unordered = isnan(a.Float[i]) || isnan(b.Float[i]);

        Assert::AreEqual(a.Float[i] == b.Float[i], laneBit(0, i));
        Assert::AreEqual(a.Float[i] != b.Float[i], laneBit(1, i));
        Assert::AreEqual(a.Float[i] < b.Float[i], laneBit(2, i));
        Assert::AreEqual(a.Float[i] <= b.Float[i], laneBit(3, i));
        Assert::AreEqual(a.Float[i] > b.Float[i], laneBit(4, i));
        Assert::AreEqual(a.Float[i] >= b.Float[i], laneBit(5, i));
        Assert::AreEqual(!unordered, laneBit(6, i));
        Assert::AreEqual(unordered, laneBit(7, i));
    }
}

TEST_METHOD(Test_Math_Double)
{
alignas(64) ZmmVal a;
Expand All @@ -34,14 +107,14 @@ namespace Assembly {

for (size_t i = 0; i < 8; i++)
{
Assert::AreEqual(result[0].Double[i], a.Double[i] + b.Double[i]);
Assert::AreEqual(result[1].Double[i], a.Double[i] - b.Double[i]);
Assert::AreEqual(result[2].Double[i], a.Double[i] * b.Double[i]);
Assert::AreEqual(result[3].Double[i], a.Double[i] / b.Double[i]);
Assert::AreEqual(result[4].Double[i], abs(b.Double[i]));
Assert::AreEqual(result[5].Double[i], sqrt(a.Double[i]));
Assert::AreEqual(result[6].Double[i], a.Double[i]);
Assert::AreEqual(result[7].Double[i], b.Double[i]);
Assert::AreEqual(a.Double[i] + b.Double[i], result[0].Double[i]);
Assert::AreEqual(a.Double[i] - b.Double[i], result[1].Double[i]);
Assert::AreEqual(a.Double[i] * b.Double[i], result[2].Double[i]);
Assert::AreEqual(a.Double[i] / b.Double[i], result[3].Double[i]);
Assert::AreEqual(abs(b.Double[i]), result[4].Double[i]);
Assert::AreEqual(sqrt(a.Double[i]), result[5].Double[i]);
Assert::AreEqual(a.Double[i], result[6].Double[i]);
Assert::AreEqual(b.Double[i], result[7].Double[i]);
}
}

Expand All @@ -65,14 +138,14 @@ namespace Assembly {

for (size_t i = 0; i < 8; i++)
{
Assert::AreEqual(result[0].Float[i], a.Float[i] + b.Float[i]);
Assert::AreEqual(result[1].Float[i], a.Float[i] - b.Float[i]);
Assert::AreEqual(result[2].Float[i], a.Float[i] * b.Float[i]);
Assert::AreEqual(result[3].Float[i], a.Float[i] / b.Float[i]);
Assert::AreEqual(result[4].Float[i], abs(b.Float[i]));
Assert::AreEqual(result[5].Float[i], sqrt(a.Float[i]));
Assert::AreEqual(result[6].Float[i], a.Float[i]);
Assert::AreEqual(result[7].Float[i], b.Float[i]);
Assert::AreEqual(a.Float[i] + b.Float[i], result[0].Float[i]);
Assert::AreEqual(a.Float[i] - b.Float[i], result[1].Float[i]);
Assert::AreEqual(a.Float[i] * b.Float[i], result[2].Float[i]);
Assert::AreEqual(a.Float[i] / b.Float[i], result[3].Float[i]);
Assert::AreEqual(abs(b.Float[i]), result[4].Float[i]);
Assert::AreEqual(sqrt(a.Float[i]), result[5].Float[i]);
Assert::AreEqual(a.Float[i], result[6].Float[i]);
Assert::AreEqual(b.Float[i], result[7].Float[i]);
}
}
};
Expand Down
2 changes: 2 additions & 0 deletions book-01/Assembly/Assembly.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@
<MASM Include="asm\avx-2\packed\avx2_p_unpack_u32_u64.asm" />
<MASM Include="asm\avx-2\scalar\avx2_s_bit_manipulation.asm" />
<MASM Include="asm\avx-2\scalar\avx2_s_flagless_instructions.asm" />
<MASM Include="asm\avx-512\packed\avx512_p_compare_double.asm" />
<MASM Include="asm\avx-512\packed\avx512_p_compare_float.asm" />
<MASM Include="asm\avx-512\packed\avx512_p_math_double.asm" />
<MASM Include="asm\avx-512\packed\avx512_p_math_float.asm" />
<MASM Include="asm\avx-512\scalar\avx512_s_calc_conditional_sum_zero_masking.asm" />
Expand Down
6 changes: 6 additions & 0 deletions book-01/Assembly/Assembly.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@
<MASM Include="asm\avx-512\packed\avx512_p_math_float.asm">
<Filter>Source Files</Filter>
</MASM>
<MASM Include="asm\avx-512\packed\avx512_p_compare_float.asm">
<Filter>Source Files</Filter>
</MASM>
<MASM Include="asm\avx-512\packed\avx512_p_compare_double.asm">
<Filter>Source Files</Filter>
</MASM>
</ItemGroup>
<ItemGroup>
<None Include="asm\macros.inc">
Expand Down
2 changes: 2 additions & 0 deletions book-01/Assembly/asm-headers/__declarations.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ extern "C" uint64_t AVX2_Scalar_Flagless_Multiply_UInt_(uint32_t a, uint32_t b,
extern "C" void AVX2_Scalar_Flagless_Shift_UInt_(uint32_t value, uint32_t count, uint32_t results[3], uint64_t flags[4]);

// AVX-512 / Packed
//
// Packed compares: a and b are read with aligned 64-byte loads, so both must
// be 64-byte aligned. result receives eight masks, one per predicate, in the
// order EQ, NEQ, LT, LE, GT, GE, ORD, UNORD; bit i of a mask is the outcome
// for lane i. The double version writes 8-bit masks (8 lanes), the float
// version 16-bit masks (16 lanes).
extern "C" void AVX512_Packed_Compare_Double_(const ZmmVal * a, const ZmmVal * b, uint8_t result[8]);
extern "C" void AVX512_Packed_Compare_Float_(const ZmmVal * a, const ZmmVal * b, uint16_t result[8]);
extern "C" void AVX512_Packed_Math_Double_(const ZmmVal * a, const ZmmVal * b, ZmmVal result[8]);
extern "C" void AVX512_Packed_Math_Float_(const ZmmVal * a, const ZmmVal * b, ZmmVal result[8]);

Expand Down
2 changes: 2 additions & 0 deletions book-01/Assembly/asm-headers/avx512_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace Assembly {
namespace AVX512 {
namespace Packed {

// Packed compares of a against b. result receives eight per-predicate lane
// masks in the order EQ, NEQ, LT, LE, GT, GE, ORD, UNORD (bit i = lane i):
// 8-bit masks for the 8 double lanes, 16-bit masks for the 16 float lanes.
__declspec(dllexport) void Compare_Double(const ZmmVal* a, const ZmmVal* b, uint8_t result[8]);
__declspec(dllexport) void Compare_Float(const ZmmVal* a, const ZmmVal* b, uint16_t result[8]);
__declspec(dllexport) void Math_Double(const ZmmVal* a, const ZmmVal* b, ZmmVal result[8]);
__declspec(dllexport) void Math_Float(const ZmmVal* a, const ZmmVal* b, ZmmVal result[8]);
}
Expand Down
10 changes: 10 additions & 0 deletions book-01/Assembly/asm-proxies/avx512_packed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ namespace Assembly {
namespace AVX512 {
namespace Packed {

// C++ entry point for the packed double compare: forwards the arguments
// unchanged to the assembly routine AVX512_Packed_Compare_Double_, which
// writes the eight predicate masks into result.
void Compare_Double(const ZmmVal* a, const ZmmVal* b, uint8_t result[8])
{
AVX512_Packed_Compare_Double_(a, b, result);
}

// C++ entry point for the packed float compare: forwards the arguments
// unchanged to the assembly routine AVX512_Packed_Compare_Float_, which
// writes the eight predicate masks into result.
void Compare_Float(const ZmmVal* a, const ZmmVal* b, uint16_t result[8])
{
AVX512_Packed_Compare_Float_(a, b, result);
}

void Math_Double(const ZmmVal* a, const ZmmVal* b, ZmmVal result[8])
{
AVX512_Packed_Math_Double_(a, b, result);
Expand Down
41 changes: 41 additions & 0 deletions book-01/Assembly/asm/avx-512/packed/avx512_p_compare_double.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
include comparison_predicates.inc

.code

; void AVX512_Packed_Compare_Double_(const ZmmVal * a, const ZmmVal * b, uint8_t result[8]);
;
; Compares the eight packed doubles of a against those of b under eight
; predicates and stores one 8-bit mask per predicate. Bit i of a mask is set
; when lane i satisfies that predicate.
;
;   rcx = a, rdx = b   read with aligned 64-byte loads (must be 64-byte aligned)
;   r8  = result       8 bytes written:
;                      [0] EQ  [1] NEQ  [2] LT  [3] LE
;                      [4] GT  [5] GE   [6] ORD [7] UNORD
;
; The CMP_* predicate values come from comparison_predicates.inc.
; NOTE(review): assuming CMP_NEQ_OQ is Intel's ordered not-equal, NaN lanes
; report 0 in result[1], unlike the C++ != operator - confirm this is intended.
AVX512_Packed_Compare_Double_ proc

; Packed-double form of the aligned load (same bits as vmovaps, matching data type).
vmovapd zmm0, zmmword ptr [rcx]         ; zmm0 = a (8 doubles)
vmovapd zmm1, zmmword ptr [rdx]         ; zmm1 = b (8 doubles)

vcmppd k1, zmm0, zmm1, CMP_EQ_OQ        ; k1[i] = a[i] == b[i]
kmovb byte ptr [r8], k1                 ; result[0]

vcmppd k1, zmm0, zmm1, CMP_NEQ_OQ       ; k1[i] = a[i] != b[i]
kmovb byte ptr [r8+1], k1               ; result[1]

vcmppd k1, zmm0, zmm1, CMP_LT_OS        ; k1[i] = a[i] < b[i]
kmovb byte ptr [r8+2], k1               ; result[2]

vcmppd k1, zmm0, zmm1, CMP_LE_OS        ; k1[i] = a[i] <= b[i]
kmovb byte ptr [r8+3], k1               ; result[3]

vcmppd k1, zmm0, zmm1, CMP_GT_OS        ; k1[i] = a[i] > b[i]
kmovb byte ptr [r8+4], k1               ; result[4]

vcmppd k1, zmm0, zmm1, CMP_GE_OS        ; k1[i] = a[i] >= b[i]
kmovb byte ptr [r8+5], k1               ; result[5]

vcmppd k1, zmm0, zmm1, CMP_ORD_S        ; k1[i] = neither a[i] nor b[i] is NaN
kmovb byte ptr [r8+6], k1               ; result[6]

vcmppd k1, zmm0, zmm1, CMP_UNORD_S      ; k1[i] = a[i] or b[i] is NaN
kmovb byte ptr [r8+7], k1               ; result[7]

vzeroupper                              ; Clear the upper vector register bits before returning
ret

AVX512_Packed_Compare_Double_ endp

end
40 changes: 40 additions & 0 deletions book-01/Assembly/asm/avx-512/packed/avx512_p_compare_float.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
include comparison_predicates.inc

.code

; void AVX512_Packed_Compare_Float_(const ZmmVal * a, const ZmmVal * b, uint16_t result[8]);
;
; Compares the sixteen packed floats of a against those of b under eight
; predicates and stores one 16-bit mask per predicate. Bit i of a mask is set
; when lane i satisfies that predicate.
;
;   rcx = a, rdx = b   read with aligned 64-byte loads (must be 64-byte aligned)
;   r8  = result       8 words written (2 bytes apart):
;                      [0] EQ  [1] NEQ  [2] LT  [3] LE
;                      [4] GT  [5] GE   [6] ORD [7] UNORD
;
; The CMP_* predicate values come from comparison_predicates.inc.
; NOTE(review): assuming CMP_NEQ_OQ is Intel's ordered not-equal, NaN lanes
; report 0 in result[1], unlike the C++ != operator - confirm this is intended.
AVX512_Packed_Compare_Float_ proc

vmovaps zmm0, zmmword ptr [rcx] ; zmm0 = a (16 floats)
vmovaps zmm1, zmmword ptr [rdx] ; zmm1 = b (16 floats)

vcmpps k1, zmm0, zmm1, CMP_EQ_OQ ; k1[i] = a[i] == b[i]
kmovw word ptr [r8], k1 ; result[0]

vcmpps k1, zmm0, zmm1, CMP_NEQ_OQ ; k1[i] = a[i] != b[i]
kmovw word ptr [r8+2], k1 ; result[1]

vcmpps k1, zmm0, zmm1, CMP_LT_OS ; k1[i] = a[i] < b[i]
kmovw word ptr [r8+4], k1 ; result[2]

vcmpps k1, zmm0, zmm1, CMP_LE_OS ; k1[i] = a[i] <= b[i]
kmovw word ptr [r8+6], k1 ; result[3]

vcmpps k1, zmm0, zmm1, CMP_GT_OS ; k1[i] = a[i] > b[i]
kmovw word ptr [r8+8], k1 ; result[4]

vcmpps k1, zmm0, zmm1, CMP_GE_OS ; k1[i] = a[i] >= b[i]
kmovw word ptr [r8+10], k1 ; result[5]

vcmpps k1, zmm0, zmm1, CMP_ORD_S ; k1[i] = neither a[i] nor b[i] is NaN
kmovw word ptr [r8+12], k1 ; result[6]

vcmpps k1, zmm0, zmm1, CMP_UNORD_S ; k1[i] = a[i] or b[i] is NaN
kmovw word ptr [r8+14], k1 ; result[7]

vzeroupper ; Clear the upper vector register bits before returning
ret

AVX512_Packed_Compare_Float_ endp

end
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ vmovsd xmm2, real8 ptr [rdx+rax*type real8] ; xmm2 = b[i]

; xmm0 = valueToIgnore
; k1[0] = 1 if a[i] != valueToIgnore
vcmpsd k1, xmm1, xmm0, CMP_NEQ
vcmpsd k1, xmm1, xmm0, CMP_NEQ_OQ

; xmm3 = (a[i] == valueToIgnore) ? 0.0 : sum(a[i], b[i])
vaddsd xmm3{k1}{z}, xmm1, xmm2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ vmovsd xmm16, xmm16, xmm3 ; Erro value

; k1[0] = 1 if radius >= 0.0
; xmm0 = radius
vcmpsd k1, xmm0, xmm5, CMP_GE
vcmpsd k1, xmm0, xmm5, CMP_GE_OS

; Area = 4 * pi * radius^2
; Volume = (area * radius) / 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ vmovaps xmm1, xmmword ptr [rdx] ; xmm1 = b
; 0x0 = false
; 0xF = true

vcmppd xmm2, xmm0, xmm1, CMP_EQ
vcmppd xmm2, xmm0, xmm1, CMP_EQ_OQ

vmovapd xmmword ptr [r8], xmm2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ jz OneByOne
vmovaps xmm0, xmmword ptr [rcx] ; xmm0 have 4 floats

; Values < 0.0F = 0.0F
vcmpps xmm1, xmm0, xmm14, CMP_LT
vcmpps xmm1, xmm0, xmm14, CMP_LT_OS
vandnps xmm2, xmm1, xmm0

; Values > 1.0F = 1.0F
vcmpps xmm3, xmm2, xmm15, CMP_GT
vcmpps xmm3, xmm2, xmm15, CMP_GT_OS
vandps xmm4, xmm3, xmm15
vandnps xmm5, xmm3, xmm2
vorps xmm6, xmm5, xmm4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ vmulps ymm3, ymm2, ymm0 ; [(4 * PI) * radius] * radius

; X < 0 = FFFFFFFFh
; X >= 0 = 00000000h
vcmpps ymm1, ymm0, ymm9, CMP_LT ; Any result < 0.0F ?
vcmpps ymm1, ymm0, ymm9, CMP_LT_OS ; Any result < 0.0F ?

vandps ymm4, ymm1, ymm8 ; area >= 0.0F = 0.0F | area < 0.0F = QNAN
vandnps ymm5, ymm1, ymm3 ; area >= 0.0F = area | area < 0.0F = 0.0F
Expand All @@ -66,7 +66,7 @@ vmovss xmm0, real4 ptr [r8 + rax]
vmulss xmm2, xmm6, xmm0
vmulss xmm3, xmm2, xmm0

vcmpss xmm1, xmm0, xmm9, CMP_LT
vcmpss xmm1, xmm0, xmm9, CMP_LT_OS

vandps xmm4, xmm1, xmm8
vandnps xmm5, xmm1, xmm3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ AVX_Scalar_Compare_VCMPSD_Double_ proc

xor r8, r8

vcmpsd xmm2, xmm0, xmm1, CMP_LT
vcmpsd xmm2, xmm0, xmm1, CMP_LT_OS
vmovq rax, xmm2
and al, 1 ; Remove unnecessary bits and set RFLAGS. If the comparison failed the result is zero and ZF is set, so the jnz below is not taken.
jnz Smaller

vcmpsd xmm2, xmm0, xmm1, CMP_EQ
vcmpsd xmm2, xmm0, xmm1, CMP_EQ_OQ
vmovq rax, xmm2
and al, 1
jnz Done

vcmpsd xmm2, xmm0, xmm1, CMP_GT
vcmpsd xmm2, xmm0, xmm1, CMP_GT_OS
vmovq rax, xmm2
and al, 1
jnz Bigger
Expand Down
Loading

0 comments on commit 49402bd

Please sign in to comment.