Skip to content

Commit

Permalink
Merge branch 'kedeng/fixCrash'
Browse files Browse the repository at this point in the history
  • Loading branch information
Ke Deng committed Sep 25, 2018
2 parents 58f810f + 1489de8 commit 9165fd0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Source/ComputationNetworkLib/LinearAlgebraNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ class TimesNodeBase : public ComputationNode<ElemType>, public NumInputs<2>
}
else
{
ElementTimesNode<ElemType>::BackpropToImpl(*this, inputIndex, fr, false/*allowBroadcast*/);
ElementTimesNode<ElemType>::BackpropToImpl(*this, inputIndex, fr, true/*allowBroadcast*/);
}
return;
}
Expand Down
16 changes: 15 additions & 1 deletion bindings/python/cntk/ops/tests/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,20 @@ def test_op_transpose_times(left_operand, right_operand, device_id, precision):
_test_binary_op(precision, device_id, times_transpose,
left_operand, right_operand, expected_forward, expected_backward)

def test_times_transpose_sequence_param(device_id, precision):
    """Regression test for a crash in times_transpose gradient computation.

    Builds a sequence input (one-hot, 2 sequences of length ``dim``) and a
    vector parameter, then requests gradients w.r.t. both the parameter and
    the sequence input. The gradient call previously crashed in the
    broadcasting backprop path, so merely executing ``z.grad`` without
    raising is the pass condition — there is deliberately no value assertion.
    """
    dt_precision = PRECISION_TO_TYPE[precision]

    from cntk import times_transpose, parameter, sequence, Value
    dim = 5
    num_sequences = 2
    seq = list(range(dim))
    input_data = Value.one_hot([seq] * num_sequences, dim, dtype=dt_precision)
    input_var = sequence.input_variable(shape=(dim), needs_gradient=True,
                                        dtype=dt_precision)
    e = parameter(shape=(dim,), init=1, dtype=dt_precision)
    z = times_transpose(e, input_var)
    # Exercise the gradient path that used to crash; result intentionally unchecked.
    e_grad = z.grad({input_var: input_data}, [e, input_var])

def test_op_times_sparse_grad(device_id, precision):
dt_precision = PRECISION_TO_TYPE[precision]

Expand All @@ -401,7 +415,7 @@ def test_op_times_sparse_grad(device_id, precision):
e = parameter(shape = (dim, dim), init = identity, dtype=dt_precision)
z = reshape(times_transpose(e, times(input_var, e)), dim)
e_grad = z.grad({input_var : input_data}, [e])

assert np.allclose(e_grad, np.ones((dim,dim))*4)

def test_op_times_reduce_sequence_axis(device_id, precision):
Expand Down

0 comments on commit 9165fd0

Please sign in to comment.