Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio Martinez committed Apr 5, 2021
1 parent b838547 commit f7bc5fe
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 66 deletions.
21 changes: 12 additions & 9 deletions tensorflow_quantum/python/differentiators/parameter_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,31 +77,34 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
# Transpose to correct shape,
# [n_programs, n_symbols, n_param_gates, n_shifts],
# then reshape to the correct batch size
batch_programs = tf.reshape(
tf.transpose(new_programs, [1, 0, 2, 3]), [n_programs, m_tile])
batch_programs = tf.reshape(tf.transpose(new_programs, [1, 0, 2, 3]),
[n_programs, m_tile])
batch_weights = tf.reshape(
tf.transpose(weights, [1, 0, 2, 3]),
[n_programs, n_symbols, n_param_gates * n_shifts])
shifts = tf.reshape(
tf.transpose(shifts, [1, 0, 2, 3]), [n_programs, m_tile, 1])
shifts = tf.reshape(tf.transpose(shifts, [1, 0, 2, 3]),
[n_programs, m_tile, 1])

# Append impurity symbol into symbol name
new_symbol_names = tf.concat([
symbol_names,
tf.constant([parameter_shift_util._PARAMETER_IMPURITY_NAME])], 0)
tf.constant([parameter_shift_util._PARAMETER_IMPURITY_NAME])
], 0)

# Symbol values are the input symbol values, tiled according to
# `batch_programs`, with the shift values appended.
tiled_symbol_values = tf.tile(
tf.expand_dims(symbol_values, 1), [1, m_tile, 1])
tiled_symbol_values = tf.tile(tf.expand_dims(symbol_values, 1),
[1, m_tile, 1])
batch_symbol_values = tf.concat([tiled_symbol_values, shifts], 2)

single_program_mapper = tf.reshape(
tf.range(n_symbols * n_param_gates * n_shifts),
[n_symbols, n_param_gates * n_shifts])
batch_mapper = tf.tile(tf.expand_dims(single_program_mapper, 0), [n_programs, 1, 1])
batch_mapper = tf.tile(tf.expand_dims(single_program_mapper, 0),
[n_programs, 1, 1])

return (batch_programs, new_symbol_names, batch_symbol_values, batch_weights, batch_mapper)
return (batch_programs, new_symbol_names, batch_symbol_values,
batch_weights, batch_mapper)

@differentiator.catch_empty_inputs
@tf.function
Expand Down
107 changes: 50 additions & 57 deletions tensorflow_quantum/python/differentiators/parameter_shift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ def test_get_gradient_circuits(self):
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(1, 2)
input_programs = util.convert_to_tensor([
cirq.Circuit(cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.Y(q1)**symbols[1]),
])
input_symbol_names = tf.constant([str(s) for s in symbols])
Expand All @@ -110,24 +111,30 @@ def test_get_gradient_circuits(self):
impurity_symbol_name = "_param_shift"
impurity_symbol = sympy.Symbol(impurity_symbol_name)
expected_batch_programs_0 = util.convert_to_tensor([
cirq.Circuit(cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(
cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(),
cirq.Circuit()
])
Expand All @@ -140,56 +147,44 @@ def test_get_gradient_circuits(self):
cirq.Circuit(cirq.Y(q1)**impurity_symbol),
cirq.Circuit(),
cirq.Circuit()
])
expected_batch_programs = tf.stack([expected_batch_programs_0,
expected_batch_programs_1])
])
expected_batch_programs = tf.stack(
[expected_batch_programs_0, expected_batch_programs_1])

# The new symbols are the old ones, with an extra used for shifting.
expected_new_symbol_names = tf.concat([
input_symbol_names, tf.constant([impurity_symbol_name])], 0)
expected_new_symbol_names = tf.concat(
[input_symbol_names,
tf.constant([impurity_symbol_name])], 0)

# The batch symbol values are the input symbol values, tiled and with
# shifted values appended. Locations that have empty programs should
# also have zero for the shift.
# The shifted values are the original value plus 1/2 divided by the
# `exponent_scalar` of the gate.
expected_batch_symbol_values = tf.constant(
[[[1.5, -2.7, 1.5 + 0.5],
[1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, 1.5 + 0.5],
[1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, -2.7 + np.pi/2],
[1.5, -2.7, -2.7 - np.pi/2],
[1.5, -2.7, -2.7],
[1.5, -2.7, -2.7]],
[[-0.3, 0.9, -0.3],
[-0.3, 0.9, -0.3],
[-0.3, 0.9, -0.3],
[-0.3, 0.9, -0.3],
[-0.3, 0.9, 0.9 + 0.5],
[-0.3, 0.9, 0.9 - 0.5],
[-0.3, 0.9, 0.9],
[-0.3, 0.9, 0.9]]])
[[[1.5, -2.7, 1.5 + 0.5], [1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, 1.5 + 0.5], [1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, -2.7 + np.pi / 2], [1.5, -2.7, -2.7 - np.pi / 2],
[1.5, -2.7, -2.7], [1.5, -2.7, -2.7]],
[[-0.3, 0.9, -0.3], [-0.3, 0.9, -0.3], [-0.3, 0.9, -0.3],
[-0.3, 0.9, -0.3], [-0.3, 0.9, 0.9 + 0.5], [-0.3, 0.9, 0.9 - 0.5],
[-0.3, 0.9, 0.9], [-0.3, 0.9, 0.9]]])

# Empty program locations are given zero weight.
expected_batch_weights = tf.constant(
[[[np.pi/2, -np.pi/2, np.pi/2, -np.pi/2],
[[[np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2],
[0.5, -0.5, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0],
[np.pi/2, -np.pi/2, 0.0, 0.0]]])
[[0.0, 0.0, 0.0, 0.0], [np.pi / 2, -np.pi / 2, 0.0, 0.0]]])

expected_batch_mapper = tf.constant(
[[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]]])
expected_batch_mapper = tf.constant([[[0, 1, 2, 3], [4, 5, 6, 7]],
[[0, 1, 2, 3], [4, 5, 6, 7]]])

(test_batch_programs, test_new_symbol_names, test_batch_symbol_values,
test_batch_weights, test_batch_mapper) = diff.get_gradient_circuits(
input_programs, input_symbol_names, input_symbol_values)
for i in range(tf.shape(input_programs)[0]):
self.assertAllEqual(util.from_tensor(expected_batch_programs[i]),
util.from_tensor(test_batch_programs[i]))
self.assertAllEqual(util.from_tensor(expected_batch_programs[i]),
util.from_tensor(test_batch_programs[i]))
self.assertAllEqual(expected_new_symbol_names, test_new_symbol_names)
self.assertAllClose(expected_batch_symbol_values,
test_batch_symbol_values,
Expand All @@ -203,9 +198,7 @@ def test_get_gradient_circuits(self):
list(
util.kwargs_cartesian_product(
**{
'differentiator': [
parameter_shift.ParameterShift(),
],
'differentiator': [parameter_shift.ParameterShift(),],
'n_qubits': [5],
'n_programs': [3],
'n_ops': [3],
Expand Down Expand Up @@ -236,8 +229,8 @@ def test_gradient_circuits_grad_comparison(self, differentiator, n_qubits,
ops_tensor = util.convert_to_tensor(psums)

# Get gradients using expectations of gradient circuits.
(batch_programs, new_symbol_names, batch_symbol_values,
batch_weights, batch_mapper) = differentiator.get_gradient_circuits(
(batch_programs, new_symbol_names, batch_symbol_values, batch_weights,
batch_mapper) = differentiator.get_gradient_circuits(
programs, symbol_names_tensor, symbol_values_tensor)
analytic_op = circuit_execution_ops.get_expectation_op()
batch_pauli_sums = tf.tile(tf.expand_dims(ops_tensor, 1),
Expand Down

0 comments on commit f7bc5fe

Please sign in to comment.