Skip to content

Commit

Permalink
Use texel dtype appropriate for ArrayBufferView dtype when uploading …
Browse files Browse the repository at this point in the history
…dense matrix to texture. (tensorflow#1793)

PERF
  • Loading branch information
annxingyuan authored Jun 18, 2019
1 parent afabd67 commit cb24f29
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
19 changes: 14 additions & 5 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2471,8 +2471,8 @@ export class MathBackendWebGL implements KernelBackend {
} else {
this.canvas = null;
}
if (this.fromPixels2DContext != null
&& this.fromPixels2DContext.canvas.remove != null) {
if (this.fromPixels2DContext != null &&
this.fromPixels2DContext.canvas.remove != null) {
this.fromPixels2DContext.canvas.remove();
}
if (this.gpgpuCreatedLocally) {
Expand Down Expand Up @@ -2535,18 +2535,27 @@ export class MathBackendWebGL implements KernelBackend {

let program;
let width = texShape[1], height = texShape[0];
const isByteArray = values instanceof Uint8Array;

if (isPacked) {
[width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight(
texShape[0], texShape[1]);
program = new EncodeMatrixPackedProgram(shapeAs3D, [height, width]);
program = new EncodeMatrixPackedProgram(
shapeAs3D, [height, width], isByteArray);
} else {
program = new EncodeMatrixProgram(shapeAs3D, [height, width]);
program =
new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray);
}

const tempDenseInputHandle =
this.makeTensorHandle([height, width], dtype);
this.texData.get(tempDenseInputHandle.dataId).usage = TextureUsage.UPLOAD;
if (isByteArray) {
this.texData.get(tempDenseInputHandle.dataId).usage =
TextureUsage.PIXELS;
} else {
this.texData.get(tempDenseInputHandle.dataId).usage =
TextureUsage.UPLOAD;
}
this.gpgpu.uploadDenseMatrixToTexture(
this.getTexture(tempDenseInputHandle.dataId), width, height,
values as TypedArray);
Expand Down
13 changes: 9 additions & 4 deletions src/backends/webgl/encode_matrix_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ export class EncodeMatrixProgram implements GPGPUProgram {
userCode: string;
outputShape: number[];

constructor(outputShape: [number, number, number], texShape: [
number, number
]) {
constructor(
outputShape: [number, number, number], texShape: [number, number],
inputIsUnsignedByte = false) {
const glsl = getGlslDifferences();
const [height, width] = texShape;
this.outputShape = outputShape;

let output = `result`;
if (inputIsUnsignedByte) {
output = `floor(result * 255. + 0.5)`;
}

this.userCode = `
${shader_util.getFlatIndexFrom3D(outputShape)}
Expand Down Expand Up @@ -58,7 +63,7 @@ export class EncodeMatrixProgram implements GPGPUProgram {
result = values[3];
}
${glsl.output} = vec4(result, 0., 0., 0.);
${glsl.output} = vec4(${output}, 0., 0., 0.);
}
`;
}
Expand Down
12 changes: 8 additions & 4 deletions src/backends/webgl/encode_matrix_packed_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ export class EncodeMatrixPackedProgram implements GPGPUProgram {
userCode: string;
outputShape: number[];

constructor(outputShape: [number, number, number], texShape: [
number, number
]) {
constructor(
outputShape: [number, number, number], texShape: [number, number],
inputIsUnsignedByte = false) {
const glsl = getGlslDifferences();
const [height, width] = texShape;
this.outputShape = outputShape;

let mainLoop = '';
let output = 'result';
if (inputIsUnsignedByte) {
output = 'floor(result * 255. + 0.5)';
}

for (let row = 0; row <= 1; row++) {
for (let col = 0; col <= 1; col++) {
Expand Down Expand Up @@ -98,7 +102,7 @@ export class EncodeMatrixPackedProgram implements GPGPUProgram {
${mainLoop}
${glsl.output} = result;
${glsl.output} = ${output};
}
`;
}
Expand Down
17 changes: 14 additions & 3 deletions src/backends/webgl/gpgpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,25 @@ export function uploadDenseMatrixToTexture(
textureConfig: TextureConfig) {
webgl_util.callAndCheck(
gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture));
const dataForUpload = new Float32Array(width * height * 4);

let dataForUpload: TypedArray, texelDataType: number, internalFormat: number;
if (data instanceof Uint8Array) {
dataForUpload = new Uint8Array(width * height * 4);
texelDataType = gl.UNSIGNED_BYTE;
internalFormat = gl.RGBA;
} else {
dataForUpload = new Float32Array(width * height * 4);
texelDataType = gl.FLOAT;
internalFormat = textureConfig.internalFormatPackedFloat;
}

dataForUpload.set(data);

webgl_util.callAndCheck(
gl, debug,
() => gl.texImage2D(
gl.TEXTURE_2D, 0, textureConfig.internalFormatPackedFloat, width,
height, 0, gl.RGBA, gl.FLOAT, dataForUpload));
gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA,
texelDataType, dataForUpload));

webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
Expand Down

0 comments on commit cb24f29

Please sign in to comment.