Skip to content

Commit

Permalink
Switch shader indexing from float to int (tensorflow#93)
Browse files Browse the repository at this point in the history
* switch shader indexing from float to int

* revert graph_runner_test

* self review
  • Loading branch information
dsmilkov authored Sep 6, 2017
1 parent 25814c5 commit 25f8967
Show file tree
Hide file tree
Showing 23 changed files with 458 additions and 329 deletions.
176 changes: 109 additions & 67 deletions src/math/math.ts

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion src/math/math_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,14 @@ export class NDArrayMathGPU extends NDArrayMath {

protected batchNormalization3DInternal(
x: Array3D, mean: Array3D|Array1D, variance: Array3D|Array1D,
varianceEpsilon = 0.000001, scale?: Array3D|Array1D,
varianceEpsilon: number|null, scale?: Array3D|Array1D,
offset?: Array3D|Array1D): Array3D {
const inputs = [x, mean, variance];

if (varianceEpsilon == null) {
varianceEpsilon = 0.000001;
}

let offsetShape = null;
if (offset != null) {
offsetShape = offset.shape;
Expand Down
9 changes: 4 additions & 5 deletions src/math/webgl/argminmax_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ export function getArgMinMaxSnippet(
const compOp = (op === 'min') ? '<' : '>';
return `
float getArgMinMax${texName}() {
float bestIndex = 0.0;
float bestValue = get${texName}Flat(0.0);
int bestIndex = 0;
float bestValue = get${texName}Flat(0);
for (int ii = 0; ii < ${size}; ii++) {
float i = float(ii);
for (int i = 0; i < ${size}; i++) {
float candidate = get${texName}Flat(i);
if (isNaN(candidate)) {
return candidate;
Expand All @@ -34,7 +33,7 @@ export function getArgMinMaxSnippet(
bestIndex = i;
}
}
return bestIndex;
return float(bestIndex);
}
`;
}
Expand Down
12 changes: 6 additions & 6 deletions src/math/webgl/concat3d_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ export class Concat3DProgram implements GPGPUProgram {
concat3d_util.computeConcat3DOutputShape(x1Shape, x2Shape, axis);
this.userCode = `
void main() {
vec3 coords = getOutputCoords();
float yR = coords.x;
float yC = coords.y;
float yD = coords.z;
ivec3 coords = getOutputCoords();
int yR = coords.x;
int yC = coords.y;
int yD = coords.z;
float value = 0.0;
if (${concatAxis} < ${x1Shape[axis]}.0) {
if (${concatAxis} < ${x1Shape[axis]}) {
value = getA(yR, yC, yD);
} else {
${concatAxis} -= ${x1Shape[axis]}.0;
${concatAxis} -= ${x1Shape[axis]};
value = getB(yR, yC, yD);
}
Expand Down
67 changes: 31 additions & 36 deletions src/math/webgl/conv_backprop_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,26 @@ export class Conv2DDerWeightsProgram implements GPGPUProgram {
this.params = [stride, zeroPad];
this.userCode = `
void main() {
vec4 coords = getOutputCoords();
float wR = coords.x;
float wC = coords.y;
float d1 = coords.z;
float d2 = coords.w;
ivec4 coords = getOutputCoords();
int wR = coords.x;
int wC = coords.y;
int d1 = coords.z;
int d2 = coords.w;
// Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int iyR = 0; iyR < ${yNumRows}; iyR++) {
float yR = float(iyR);
float xR = wR + yR * ${stride}.0 - ${zeroPad}.0;
for (int yR = 0; yR < ${yNumRows}; yR++) {
int xR = wR + yR * ${stride} - ${zeroPad};
if (xR < 0.0 || xR >= ${xNumRows}.0) {
if (xR < 0 || xR >= ${xNumRows}) {
continue;
}
for (int iyC = 0; iyC < ${yNumCols}; iyC++) {
float yC = float(iyC);
float xC = wC + yC * ${stride}.0 - ${zeroPad}.0;
for (int yC = 0; yC < ${yNumCols}; yC++) {
int xC = wC + yC * ${stride} - ${zeroPad};
if (xC < 0.0 || xC >= ${xNumCols}.0) {
if (xC < 0 || xC >= ${xNumCols}) {
continue;
}
Expand Down Expand Up @@ -94,42 +92,41 @@ export class Conv2DTransposeProgram implements GPGPUProgram {
this.params = [pad, fSize, origStride, hasBias];

this.userCode = `
const ivec2 pads = ivec2(${pad}, ${pad});
void main() {
vec3 coords = getOutputCoords();
float yR = coords.x;
float yC = coords.y;
float d2 = coords.z;
ivec3 coords = getOutputCoords();
int d2 = coords.z;
vec2 xRCCorner = vec2(yR, yC) - vec2(${pad}.0, ${pad}.0);
float xRCorner = xRCCorner.x;
float xCCorner = xRCCorner.y;
ivec2 xRCCorner = coords.xy - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// Convolve x(?, ?, d1) with w(:, :, d2, d1) to get y(yR, yC, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int iwR = 0; iwR < ${fSize}; iwR++) {
float wR = float(iwR);
float xR = (xRCorner + wR) / ${origStride}.0;
for (int wR = 0; wR < ${fSize}; wR++) {
float xR = float(xRCorner + wR) / ${origStride}.0;
if (xR < 0.0 || xR >= ${xRows}.0 || fract(xR) > 0.0) {
continue;
}
int ixR = int(xR);
float wRPerm = ${fSize}.0 - 1.0 - wR;
int wRPerm = ${fSize} - 1 - wR;
for (int iwC = 0; iwC < ${fSize}; iwC++) {
float wC = float(iwC);
float xC = (xCCorner + wC) / ${origStride}.0;
for (int wC = 0; wC < ${fSize}; wC++) {
float xC = float(xCCorner + wC) / ${origStride}.0;
if (xC < 0.0 || xC >= ${xCols}.0 || fract(xC) > 0.0) {
continue;
}
int ixC = int(xC);
float wCPerm = ${fSize}.0 - 1.0 - wC;
int wCPerm = ${fSize} - 1 - wC;
for (int id1 = 0; id1 < ${origOutputDepth}; id1++) {
float d1 = float(id1);
float xValue = getX(xR, xC, d1);
for (int d1 = 0; d1 < ${origOutputDepth}; d1++) {
float xValue = getX(ixR, ixC, d1);
float wValue = getW(wRPerm, wCPerm, d2, d1);
dotProd += xValue * wValue;
}
Expand All @@ -153,13 +150,11 @@ export class Conv2DDerBiasProgram implements GPGPUProgram {
this.outputShape = [outputDepth];
this.userCode = `
void main() {
float d2 = getOutputCoords();
int d2 = getOutputCoords();
float derBias = 0.0;
for (int iyR = 0; iyR < ${yNumRows}; iyR++) {
float yR = float(iyR);
for (int iyC = 0; iyC < ${yNumCols}; iyC++) {
float yC = float(iyC);
for (int yR = 0; yR < ${yNumRows}; yR++) {
for (int yC = 0; yC < ${yNumCols}; yC++) {
derBias += getDy(yR, yC, d2);
}
}
Expand Down
33 changes: 15 additions & 18 deletions src/math/webgl/conv_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,35 @@ export class Conv2DProgram implements GPGPUProgram {
const xNumRows = xShape[0];
const xNumCols = xShape[1];
this.userCode = `
const ivec2 strides = ivec2(${stride}, ${stride});
const ivec2 pads = ivec2(${pad}, ${pad});
void main() {
vec3 coords = getOutputCoords();
float yR = coords.x;
float yC = coords.y;
float d2 = coords.z;
ivec3 coords = getOutputCoords();
int d2 = coords.z;
vec2 xRCCorner = vec2(yR, yC) * vec2(${stride}.0, ${stride}.0) -
vec2(${pad}.0, ${pad}.0);
float xRCorner = xRCCorner.x;
float xCCorner = xRCCorner.y;
ivec2 xRCCorner = coords.xy * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
// Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
for (int iwR = 0; iwR < ${fieldSize}; iwR++) {
float wR = float(iwR);
float xR = xRCorner + wR;
for (int wR = 0; wR < ${fieldSize}; wR++) {
int xR = xRCorner + wR;
if (xR < 0.0 || xR >= ${xNumRows}.0) {
if (xR < 0 || xR >= ${xNumRows}) {
continue;
}
for (int iwC = 0; iwC < ${fieldSize}; iwC++) {
float wC = float(iwC);
float xC = xCCorner + wC;
for (int wC = 0; wC < ${fieldSize}; wC++) {
int xC = xCCorner + wC;
if (xC < 0.0 || xC >= ${xNumCols}.0) {
if (xC < 0 || xC >= ${xNumCols}) {
continue;
}
for (int id1 = 0; id1 < ${inputDepth}; id1++) {
float d1 = float(id1);
for (int d1 = 0; d1 < ${inputDepth}; d1++) {
float xValue = getX(xR, xC, d1);
float wValue = getW(wR, wC, d1, d2);
dotProd += xValue * wValue;
Expand Down
18 changes: 8 additions & 10 deletions src/math/webgl/copy_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ export class Copy2DProgram implements GPGPUProgram {
this.outputShape = null;
this.params = [srcNumCols, destNumCols];
this.userCode = `
uniform vec2 sourceStart;
uniform vec2 destStart;
uniform ivec2 sourceStart;
uniform ivec2 destStart;
void main() {
vec2 destCoords = getOutputCoords() - destStart;
float index = dot(destCoords, vec2(${destNumCols}.0, 1.0));
vec2 sourceCoords = sourceStart + vec2(
floor(index / ${srcNumCols}.0),
mod(index, ${srcNumCols}.0)
);
ivec2 destCoords = getOutputCoords() - destStart;
int index = destCoords.x * ${destNumCols} + destCoords.y;
int r = index / ${srcNumCols};
ivec2 sourceCoords = sourceStart + ivec2(r, index - r * ${srcNumCols});
setOutput(getSource(sourceCoords.x, sourceCoords.y));
}
`;
Expand All @@ -48,9 +46,9 @@ export class Copy2DProgram implements GPGPUProgram {
gpgpu.setOutputMatrixWriteRegion(
destStart[0], destSize[0], destStart[1], destSize[1]);
const sourceStartCRLoc = gpgpu.getUniformLocation('sourceStart');
gpgpu.gl.uniform2f(sourceStartCRLoc, sourceStart[0], sourceStart[1]);
gpgpu.gl.uniform2i(sourceStartCRLoc, sourceStart[0], sourceStart[1]);
const destStartCRLoc = gpgpu.getUniformLocation('destStart');
gpgpu.gl.uniform2f(destStartCRLoc, destStart[0], destStart[1]);
gpgpu.gl.uniform2i(destStartCRLoc, destStart[0], destStart[1]);
};
}
}
10 changes: 7 additions & 3 deletions src/math/webgl/gpgpu_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ export class GPGPUContext {
if (!webgl_util.isWebGL2Enabled()) {
this.textureFloatExtension =
webgl_util.getExtensionOrThrow(this.gl, 'OES_texture_float');
this.colorBufferFloatExtension =
this.gl.getExtension('WEBGL_color_buffer_float');
} else {
this.colorBufferFloatExtension =
webgl_util.getExtensionOrThrow(this.gl, 'EXT_color_buffer_float');
}

this.loseContextExtension =
webgl_util.getExtensionOrThrow(this.gl, 'WEBGL_lose_context') as
WebGLLoseContextExtension;
this.loseContextExtension = webgl_util.getExtensionOrThrow(
this.gl, 'WEBGL_lose_context') as WebGLLoseContextExtension;
this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl);
this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl);
this.framebuffer = webgl_util.createFramebuffer(this.gl);
Expand Down Expand Up @@ -258,6 +259,9 @@ export class GPGPUContext {
this.throwIfDisposed();
webgl_util.bindColorTextureToFramebuffer(
this.gl, texture, this.framebuffer);
if (this.autoDebugValidate) {
webgl_util.validateFramebuffer(this.gl);
}
const result = downloadAndDecode();
if (this.outputTexture != null) {
webgl_util.bindColorTextureToFramebuffer(
Expand Down
22 changes: 20 additions & 2 deletions src/math/webgl/gpgpu_context_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,21 @@ describe('GPGPUContext downloadMatrixFromTexture WebGL 2.0', () => {
expect(result[0]).toBeCloseTo(0.123);
});

it('returns matrix that was uploaded', () => {
it('returns 1x1 matrix that was uploaded', () => {
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1.234]));
const result = gpgpu.downloadMatrixFromTexture(texture, 1, 1);
expect(result[0]).toBeCloseTo(1.234);
});

it('returns 2x2 matrix that was uploaded', () => {
const texture2 = gpgpu.createMatrixTexture(2, 2);
gpgpu.uploadMatrixToTexture(
texture2, 2, 2, new Float32Array([1.234, 2, 3, 4]));
const result = gpgpu.downloadMatrixFromTexture(texture2, 2, 2);
expect(result).toEqual(new Float32Array([1.234, 2, 3, 4]));
gpgpu.deleteMatrixTexture(texture2);
});

it('uses texture parameter', () => {
const texture2: WebGLTexture = gpgpu.createMatrixTexture(1, 1);
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1]));
Expand Down Expand Up @@ -84,12 +93,21 @@ describe('GPGPUContext downloadMatrixFromTexture WebGL 1.0', () => {
expect(result[0]).toBeCloseTo(0.123);
});

it('returns matrix that was uploaded', () => {
it('returns 1x1 matrix that was uploaded', () => {
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1.234]));
const result = gpgpu.downloadMatrixFromTexture(texture, 1, 1);
expect(result[0]).toBeCloseTo(1.234);
});

it('returns 2x2 matrix that was uploaded', () => {
const texture2 = gpgpu.createMatrixTexture(2, 2);
gpgpu.uploadMatrixToTexture(
texture2, 2, 2, new Float32Array([1.234, 2, 3, 4]));
const result = gpgpu.downloadMatrixFromTexture(texture2, 2, 2);
expect(result).toEqual(new Float32Array([1.234, 2, 3, 4]));
gpgpu.deleteMatrixTexture(texture2);
});

it('uses texture parameter', () => {
const texture2: WebGLTexture = gpgpu.createMatrixTexture(1, 1);
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1]));
Expand Down
2 changes: 1 addition & 1 deletion src/math/webgl/gpgpu_math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ export function makeShaderKey(
const params = program.params;
const keyStart =
inputs.concat(output).map(x => x.shape + '_' + x.getTextureShapeRC());
const keyEnd = params.map(p => p.toString());
const keyEnd = params.map(String);
let key = [program.constructor.name];
key.push((program.supportsBroadcasting === true).toString());
key = key.concat(keyStart, keyEnd);
Expand Down
23 changes: 16 additions & 7 deletions src/math/webgl/gpgpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ function getTextureInternalFormat(

function getTextureFormat(
gl: WebGLRenderingContext, numChannels: number): number {
if (webgl_util.isWebGL2Enabled() && numChannels === 1) {
if (webgl_util.isWebGL2Enabled()) {
if (numChannels === 4) {
// tslint:disable-next-line:no-any
return (gl as any).RGBA;
}
// tslint:disable-next-line:no-any
return (gl as any).RED;
}
Expand Down Expand Up @@ -206,12 +210,17 @@ export function uploadMatrixToTexture(

const channelsPerTexture =
numChannels === 1 ? webgl_util.getChannelsPerTexture() : numChannels;
const unpackedArray =
new Float32Array(tex_util.getUnpackedArraySizeFromMatrixSize(
matrix.length, channelsPerTexture));
tex_util.encodeMatrixToUnpackedArray(
matrix, unpackedArray, channelsPerTexture);

let unpackedArray: Float32Array;
if (channelsPerTexture === 1) {
// No need to allocate a temporary array.
unpackedArray = matrix;
} else {
unpackedArray =
new Float32Array(tex_util.getUnpackedArraySizeFromMatrixSize(
matrix.length, channelsPerTexture));
tex_util.encodeMatrixToUnpackedArray(
matrix, unpackedArray, channelsPerTexture);
}
uploadDataToTexture(gl, texture, w, h, unpackedArray, numChannels);
}

Expand Down
Loading

0 comments on commit 25f8967

Please sign in to comment.