Skip to content

Commit

Permalink
added symbol values test
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio Martinez committed Apr 2, 2021
1 parent a689b1f commit 750ce65
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
33 changes: 14 additions & 19 deletions tensorflow_quantum/python/differentiators/parameter_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,23 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
# then reshape to the correct batch size
batch_programs = tf.reshape(
tf.transpose(new_programs, [1, 0, 2, 3]), [n_programs, m_tile])

weights = tf.transpose(weights, [0, 2, 3, 1])
shifts = tf.transpose(shifts, [0, 2, 3, 1])


# # tile up and then reshape to order ops correctly
# flat_perturbations = tf.concat([
# tf.reshape(
# tf.tile(tf.expand_dims(symbol_values, 0),
# tf.stack([n_tile, 1, 1])), [total_programs, n_symbols]),
# tf.expand_dims(flat_shifts, axis=1)
# ],
# axis=1)
weights = tf.reshape(
tf.transpose(weights, [1, 0, 2, 3]), [n_programs, m_tile])
shifts = tf.reshape(
tf.transpose(shifts, [1, 0, 2, 3]), [n_programs, m_tile])

# Append impurity symbol into symbol name
new_symbol_names = tf.concat([
expected_new_symbol_names = tf.concat([
symbol_names,
tf.expand_dims(tf.constant(
parameter_shift_util._PARAMETER_IMPURITY_NAME),
axis=0)
],
axis=0)
tf.constant([parameter_shift_util._PARAMETER_IMPURITY_NAME])], 0)

# Symbol values are the input symbol values, tiled according to
# `batch_programs`, with the shift values appended.
tiled_symbol_names = tf.tile(
tf.expand_dims(symbol_values, 1), [1, m_tile, 1])
tiled_shifts = tf.expand_dims(shifts, 1)
batch_symbol_values = tf.concat(
[tiled_symbol_names, tiled_symbol_values], 2)

return (batch_programs, new_symbol_names, None, None)

Expand Down
32 changes: 28 additions & 4 deletions tensorflow_quantum/python/differentiators/parameter_shift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,40 @@ def test_get_gradient_circuits(self):
expected_new_symbol_names = tf.concat([
input_symbol_names, tf.constant([impurity_symbol_name])], 0)

# The batch symbol values are the input symbol values, tiled and with
# shifts appended. Locations that have empty programs should also have
# zero for the shift.
# The shift values are the shifted symbol value, plus 1/2 divided by the
# `exponent_scalar` of the gates.
expected_batch_symbol_values = tf.constant(
[[[1.5, -2.7, 1.5 + 0.5],
[1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, 1.5 + 0.5],
[1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, -2.7 + 1/(2*np.pi)],
[1.5, -2.7, -2.7 - 1/(2*np.pi)],
[1.5, -2.7, 0.0],
[1.5, -2.7, 0.0]],
[[-0.3, 0.9, 0.0],
[-0.3, 0.9, 0.0],
[-0.3, 0.9, 0.0],
[-0.3, 0.9, 0.0],
[-0.3, 0.9, 0.9 + 0.5],
[-0.3, 0.9, 0.9 - 0.5],
[-0.3, 0.9, 0.0],
[-0.3, 0.9, 0.0]]])

#
(test_batch_programs, test_new_symbol_names, test_batch_symbol_values,
test_batch_mapper) = diff.get_gradient_circuits(
input_programs, input_symbol_names, input_symbol_values)
for i in range(2):
for i in range(tf.shape(input_programs)[0]):
self.assertAllEqual(util.from_tensor(expected_batch_programs[i]),
util.from_tensor(test_batch_programs[i]))
self.assertAllEqual(expected_new_symbol_names, test_new_symbol_names)
# self.assertAllClose(expected_batch_symbol_values,
# test_batch_symbol_values,
# atol=1e-6)
self.assertAllClose(expected_batch_symbol_values,
test_batch_symbol_values,
atol=1e-6)
# self.assertAllClose(expected_batch_mapper, test_batch_mapper, atol=1e-6)

# @parameterized.parameters(
Expand Down

0 comments on commit 750ce65

Please sign in to comment.