Skip to content

Commit

Permalink
Fix dl.clone to do a shallow copy. Also remove backend.clone() (tenso…
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmilkov authored Feb 1, 2018
1 parent 87f1adf commit 8613dc8
Show file tree
Hide file tree
Showing 5 changed files with 1 addition and 37 deletions.
3 changes: 1 addition & 2 deletions src/math/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ export class Ops {
/** Creates a ndarray with the same values/shape as the specified ndarray. */
@operation
static clone<T extends NDArray>(x: T): T {
const newValues = util.copyTypedArray(x.dataSync(), x.dtype);
return NDArray.make(x.shape, {values: newValues}, x.dtype) as T;
return NDArray.make(x.shape, {dataId: x.dataId}, x.dtype) as T;
}

@operation
Expand Down
2 changes: 0 additions & 2 deletions src/math/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ export interface MathBackend extends NDArrayStorage, BackendTimer {
a: Array2D, b: Array2D, aOrientation: MatrixOrientation,
bOrientation: MatrixOrientation): Array2D;

clone<T extends NDArray>(ndarray: T): T;

slice1D(x: Array1D, begin: number, size: number): Array1D;
slice2D(x: Array2D, begin: [number, number], size: [number, number]): Array2D;
slice3D(x: Array3D, begin: [number, number, number], size: [
Expand Down
4 changes: 0 additions & 4 deletions src/math/backends/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ export class MathBackendCPU implements MathBackend {
}
}

clone<T extends NDArray>(x: T): T {
return NDArray.make(x.shape, {values: new Float32Array(x.dataSync())}) as T;
}

slice1D(x: Array1D, begin: number, size: number): Array1D {
const newVals = x.dataSync().slice(begin, begin + size);
return Array1D.new(newVals, x.dtype);
Expand Down
25 changes: 0 additions & 25 deletions src/math/backends/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import {ConcatProgram} from './webgl/concat_gpu';
import {Conv2DDerBiasProgram, Conv2DDerFilterProgram, Conv2DDerInputProgram} from './webgl/conv_backprop_gpu';
import {Conv2DProgram} from './webgl/conv_gpu';
import {DepthwiseConv2DProgram} from './webgl/conv_gpu_depthwise';
import {Copy2DProgram} from './webgl/copy_gpu';
import {FromPixelsProgram} from './webgl/from_pixels_gpu';
import {GatherProgram} from './webgl/gather_gpu';
import {GPGPUContext} from './webgl/gpgpu_context';
Expand Down Expand Up @@ -283,19 +282,6 @@ export class MathBackendWebGL implements MathBackend {
return this.gpgpu;
}

clone<T extends NDArray>(x: T): T {
this.throwIfNoData(x.dataId);
this.uploadToGPU(x.dataId);
const {texShape} = this.texData[x.dataId];
// Pretend the source was in logical shape that matches the texture shape.
const source = x.as2D(texShape[0], texShape[1]);
// Do the same for output.
const output = this.makeOutputArray<Array2D>(texShape, x.dtype);
this.copy2D(source, [0, 0], texShape, output, [0, 0], texShape);
// Get back to the original logical shape.
return output.reshape(x.shape) as T;
}

slice1D(x: Array1D, begin: number, size: number): Array1D {
const program = new SliceProgram([size]);
const customSetup = program.getCustomSetupFunc([begin]);
Expand Down Expand Up @@ -330,17 +316,6 @@ export class MathBackendWebGL implements MathBackend {
return this.compileAndRun(program, [x]);
}

private copy2D(
source: Array2D, sourceBeginRowCol: [number, number],
sourceSizeRowCol: [number, number], dest: Array2D,
destBeginRowCol: [number, number],
destSizeRowCol: [number, number]): void {
const program = new Copy2DProgram(sourceSizeRowCol[1], destSizeRowCol[1]);
const customSetup = program.getCustomSetupFunc(
sourceBeginRowCol, destBeginRowCol, destSizeRowCol);
this.compileAndRun(program, [source], dest, customSetup);
}

// Concats 2d tensors along axis=1. See comments in MathBackend.concat().
concat(a: Array2D, b: Array2D): Array2D {
const program = new ConcatProgram(a.shape, b.shape);
Expand Down
4 changes: 0 additions & 4 deletions src/math/backends/kernel_registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, O extends
return backend.matMul(
config.inputs.a, config.inputs.b, config.args.aOrientation,
config.args.bOrientation) as O;
} else if (kernelName === 'Clone') {
const config = inputAndArgs as UnaryNode<R>['inputAndArgs'];
return backend.clone(config.inputs.x) as O;
} else if (kernelName === 'Slice1D') {
const config = inputAndArgs as Slice1DNode['inputAndArgs'];
return backend.slice1D(
Expand Down Expand Up @@ -363,7 +360,6 @@ executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, O extends

export interface KernelConfigRegistry<R extends Rank> {
MatMul: MatMulNode;
Clone: UnaryNode<R>;
Slice1D: Slice1DNode;
Slice2D: Slice2DNode;
Slice3D: Slice3DNode;
Expand Down

0 comments on commit 8613dc8

Please sign in to comment.