Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jul 26, 2022
1 parent dd4c484 commit cc46344
Showing 1 changed file with 54 additions and 54 deletions.
108 changes: 54 additions & 54 deletions freqencoder/src/freqencoder.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ inline constexpr __device__ float PI() { return 3.141592653589793f; }

// Ceiling division: smallest q such that q * divisor >= val.
// Used on the host to size kernel grids; usable from device code too.
// NOTE(review): val + divisor - 1 can wrap for val near T's max — callers here
// pass small element counts, but confirm if reused elsewhere.
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// Frequency (positional) encoding, forward pass.
// inputs: [B, D]
// outputs: [B, C], with C = D + D * deg * 2 (identity + sin/cos per frequency)
// Launch: 1D grid, one thread per output element (B * C threads total).
// NOTE(review): the first parameter line is collapsed out of this diff view;
// reconstructed from the launch site in freq_encode_forward — confirm against the full file.
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index: batch row b, output column c
    const uint32_t b = t / C;
    const uint32_t c = t - b * C; // t % C;

    // locate this thread's input row and output element
    inputs += b * D;
    outputs += t;

    if (c < D) {
        // identity part: first D columns copy the input through unchanged
        outputs[0] = inputs[c];
    } else {
        // frequency part: columns are grouped as (sin, cos) pairs per frequency
        const uint32_t col = c / D - 1;              // encoded column index (freq/phase pair)
        const uint32_t d = c % D;                    // which input dimension
        const uint32_t freq = col / 2;               // frequency exponent f: input scaled by 2^f
        // even col -> sin, odd col -> cos (sin phase-shifted by pi/2)
        const float phase_shift = (col % 2) * (PI() / 2);
        // scalbnf(x, f) == x * 2^f without a pow call; __sinf is the fast intrinsic
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}

// Frequency encoding, backward pass.
// grad: [B, C], C = D + D * deg * 2   (upstream gradient w.r.t. outputs)
// outputs: [B, C]                     (forward results, reused to avoid recomputing sin/cos)
// grad_inputs: [B, D]                 (gradient w.r.t. inputs, written here)
// Launch: 1D grid, one thread per input element (B * D threads total).
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D; // t % D;

    // locate this thread's batch row
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // identity part: d(out)/d(in) = 1, so its gradient passes straight through
    float result = grad[d];
    grad += D;
    outputs += D;

    // per frequency f: d sin(2^f x)/dx = 2^f cos(2^f x) and d cos(2^f x)/dx = -2^f sin(2^f x);
    // outputs[d] is the sin column, outputs[D + d] the cos column for this frequency.
    // scalbnf(1.0f, f) == 2^f.
    for (uint32_t f = 0; f < deg; f++) {
        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write accumulated gradient for this input element
    grad_inputs[0] = result;
}


// Host launcher for the forward frequency encoding.
// inputs: [B, D] float CUDA tensor; outputs: [B, C] float CUDA tensor, C = D + D * deg * 2.
// NOTE(review): the diff collapses the opening lines of this function — the parameter
// list tail and any CHECK_CUDA/CHECK_CONTIGUOUS guards are reconstructed/omitted here;
// confirm against the full file.
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;

    // one thread per output element; grid sized by ceil-div
    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}


// Host launcher for the backward frequency encoding.
// grad, outputs: [B, C] float CUDA tensors; grad_inputs: [B, D] float CUDA tensor.
// NOTE(review): the diff collapses the opening lines of this function — the parameter
// list tail and any CHECK_CUDA/CHECK_CONTIGUOUS guards are reconstructed/omitted here;
// confirm against the full file.
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;

    // one thread per input element; grid sized by ceil-div
    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}

0 comments on commit cc46344

Please sign in to comment.