forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Seryilmaz/fused dropout softmax (NVIDIA#985)
* fuse dropout into softmax in fprop for additive mask case
- Loading branch information
Showing
9 changed files
with
1,019 additions
and
140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#pragma once | ||
//Philox CUDA. | ||
|
||
class Philox { | ||
public: | ||
__device__ inline Philox(unsigned long long seed, | ||
unsigned long long subsequence, | ||
unsigned long long offset) { | ||
key.x = (unsigned int)seed; | ||
key.y = (unsigned int)(seed >> 32); | ||
counter = make_uint4(0, 0, 0, 0); | ||
counter.z = (unsigned int)(subsequence); | ||
counter.w = (unsigned int)(subsequence >> 32); | ||
STATE = 0; | ||
incr_n(offset / 4); | ||
} | ||
__device__ inline uint4 operator()() { | ||
if(STATE == 0) { | ||
uint4 counter_ = counter; | ||
uint2 key_ = key; | ||
//7-round philox | ||
for(int i = 0; i < 6; i++) { | ||
counter_ = single_round(counter_, key_); | ||
key_.x += (kPhilox10A); key_.y += (kPhilox10B); | ||
} | ||
output = single_round(counter_, key_); | ||
incr(); | ||
} | ||
//return a float4 directly | ||
//unsigned long ret; | ||
//switch(STATE) { | ||
// case 0: ret = output.x; break; | ||
// case 1: ret = output.y; break; | ||
// case 2: ret = output.z; break; | ||
// case 3: ret = output.w; break; | ||
//} | ||
//STATE = (STATE + 1) % 4; | ||
return output; | ||
} | ||
private: | ||
uint4 counter; | ||
uint4 output; | ||
uint2 key; | ||
unsigned int STATE; | ||
__device__ inline void incr_n(unsigned long long n) { | ||
unsigned int nlo = (unsigned int)(n); | ||
unsigned int nhi = (unsigned int)(n >> 32); | ||
counter.x += nlo; | ||
if (counter.x < nlo) | ||
nhi++; | ||
counter.y += nhi; | ||
if (nhi <= counter.y) | ||
return; | ||
if (++counter.z) | ||
return; | ||
++counter.w; | ||
} | ||
__device__ inline void incr() { | ||
if (++counter.x) | ||
return; | ||
if (++counter.y) | ||
return; | ||
if (++counter.z) | ||
return; | ||
++counter.w; | ||
} | ||
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b, | ||
unsigned int *result_high) { | ||
*result_high = __umulhi(a, b); | ||
return a*b; | ||
} | ||
__device__ inline uint4 single_round(uint4 ctr, uint2 key) { | ||
unsigned int hi0; | ||
unsigned int hi1; | ||
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); | ||
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); | ||
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; | ||
return ret; | ||
} | ||
static const unsigned long kPhilox10A = 0x9E3779B9; | ||
static const unsigned long kPhilox10B = 0xBB67AE85; | ||
static const unsigned long kPhiloxSA = 0xD2511F53; | ||
static const unsigned long kPhiloxSB = 0xCD9E8D57; | ||
}; | ||
// Inverse of 2^32. | ||
#define M_RAN_INVM32 2.3283064e-10f | ||
__device__ __inline__ float4 uniform4(uint4 x) { | ||
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.