Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Support complex type in concat op (#1829)
Browse files Browse the repository at this point in the history
FEATURE

Modify the concat op to support the complex64 dtype. This is also needed so that the stft op can return complex values of arbitrary shape.
  • Loading branch information
Lewuathe authored and Nikhil Thorat committed Aug 6, 2019
1 parent 5cc5267 commit 975e5f6
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
import {real, imag, complex} from '../../ops/complex_ops';

function mapActivation(
backend: MathBackendCPU, x: Tensor, activation: Activation,
Expand Down Expand Up @@ -380,7 +381,11 @@ export class MathBackendCPU implements KernelBackend {
}

concat(tensors: Tensor[], axis: number): Tensor {
this.assertNotComplex(tensors, 'concat');
if (tensors[0].dtype === 'complex64') {
const reals = tensors.map((t) => real(t));
const imags = tensors.map((t) => imag(t));
return complex(this.concat(reals, axis), this.concat(imags, axis));
}
const tensors2D = tensors.map(t => {
const innerSize = util.sizeFromShape(t.shape.slice(axis));
return t.as2D(-1, innerSize);
Expand Down
6 changes: 6 additions & 0 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ import * as unary_packed_op from './unaryop_packed_gpu';
import {UnaryOpPackedProgram} from './unaryop_packed_gpu';
import {UnpackProgram} from './unpack_gpu';
import * as webgl_util from './webgl_util';
import {real, imag, complex} from '../../ops/complex_ops';

type KernelInfo = {
name: string; query: Promise<number>;
Expand Down Expand Up @@ -798,6 +799,11 @@ export class MathBackendWebGL implements KernelBackend {
}

concat(tensors: Tensor[], axis: number): Tensor {
if (tensors[0].dtype === 'complex64') {
const reals = tensors.map((t) => real(t));
const imags = tensors.map((t) => imag(t));
return complex(this.concat(reals, axis), this.concat(imags, axis));
}
if (this.shouldExecuteOnCPU(tensors)) {
return this.cpuBackend.concat(tensors, axis);
}
Expand Down
9 changes: 9 additions & 0 deletions src/ops/concat_split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ function concat4d_(
function concat_<T extends Tensor>(tensors: Array<T|TensorLike>, axis = 0): T {
assert(tensors.length >= 1, () => 'Pass at least one tensor to concat');
let $tensors = convertToTensorArray(tensors, 'tensors', 'concat');
if ($tensors[0].dtype === 'complex64') {
$tensors.forEach(tensor => {
if (tensor.dtype !== 'complex64') {
throw new Error(`Cannot concatenate complex64 tensors with a tensor
with dtype ${tensor.dtype}. `);
}
});
}

axis = parseAxisParam(axis, $tensors[0].shape)[0];
const outShape = computeOutShape($tensors.map(t => t.shape), axis);
if (sizeFromShape(outShape) === 0) {
Expand Down
87 changes: 87 additions & 0 deletions src/ops/concat_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
const expected = [3, 5];
expectArraysClose(await result.data(), expected);
});

it('concat complex input', async() => {
// [1+1j, 2+2j]
const c1 = tf.complex([1, 2], [1, 2]);
// [3+3j, 4+4j]
const c2 = tf.complex([3, 4], [3, 4]);

const axis = 0;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});
});

describeWithFlags('concat2d', ALL_ENVS, () => {
Expand Down Expand Up @@ -220,6 +233,32 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
expect(res2.shape).toEqual([0, 15]);
expectArraysEqual(await res2.data(), []);
});

it('concat complex input axis=0', async() => {
// [[1+1j, 2+2j], [3+3j, 4+4j]]
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
// [[5+5j, 6+6j], [7+7j, 8+8j]]
const c2 = tf.complex([[5, 6], [7, 8]], [[5, 6], [7, 8]]);

const axis = 0;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});

it('concat complex input axis=1', async() => {
// [[1+1j, 2+2j], [3+3j, 4+4j]]
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
// [[5+5j, 6+6j], [7+7j, 8+8j]]
const c2 = tf.complex([[5, 6], [7, 8]], [[5, 6], [7, 8]]);

const axis = 1;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 5, 5, 6, 6, 3, 3, 4, 4, 7, 7, 8, 8];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});
});

describeWithFlags('concat3d', ALL_ENVS, () => {
Expand Down Expand Up @@ -460,6 +499,54 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
expect(values.shape).toEqual([2, 3, 1]);
expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]);
});

it('concat complex input axis=0', async() => {
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
const c1 = tf.complex(
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
const c2 = tf.complex(
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

const axis = 0;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});

it('concat complex input axis=1', async() => {
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
const c1 = tf.complex(
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
const c2 = tf.complex(
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

const axis = 1;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});

// BUG FIX: this spec was named 'axis=1', duplicating the previous test's
// name even though it exercises axis = 2. The description now matches the
// axis actually tested.
it('concat complex input axis=2', async () => {
  // c1 = [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
  const c1 = tf.complex(
      [[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
  // c2 = [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
  const c2 = tf.complex(
      [[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

  const axis = 2;
  const result = tf.concat([c1, c2], axis);
  // Concat along the innermost axis interleaves per row: [c1_row, c2_row].
  // data() yields interleaved [real, imag] pairs.
  const expected = [1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
                    9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12];
  expect(result.dtype).toEqual('complex64');
  expectArraysClose(await result.data(), expected);
});
});

describeWithFlags('concat throws for non-tensors', ALL_ENVS, () => {
Expand Down

0 comments on commit 975e5f6

Please sign in to comment.