Skip to content

Commit

Permalink
Merge pull request tensorflow#532 from zaqqwerty/515_ps_diff
Browse files Browse the repository at this point in the history
Gradient circuits 3/n: implementation for ParameterShift
  • Loading branch information
jaeyoo authored Apr 7, 2021
2 parents 244c97c + d474bd5 commit 9c8dccc
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 21 deletions.
23 changes: 17 additions & 6 deletions tensorflow_quantum/python/differentiators/differentiator.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
>>> exp_layer = tfq.layers.Expectation()
>>> batch_pauli_sums = tf.tile(
... tf.expand_dims(pauli_sums, 1),
... [1, tf.shape(batch_mapper)[2], 1])
... [1, tf.shape(batch_programs)[1], 1])
>>> n_batch_programs = tf.reduce_prod(tf.shape(batch_programs))
>>> n_symbols = tf.shape(new_symbol_names)[0]
>>> n_ops = tf.shape(pauli_sums)[1]
Expand All @@ -245,8 +245,11 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
... batch_pauli_sums, [n_batch_programs, n_ops]))
>>> batch_expectations = tf.reshape(
... batch_expectations, tf.shape(batch_pauli_sums))
>>> grad_manual = tf.reduce_sum(
... tf.einsum('ikm,imp->ikp', batch_mapper, batch_expectations), -1)
>>> batch_jacobian = tf.map_fn(
... lambda x: tf.einsum('km,kmp->kp', x[0], tf.gather(x[1], x[2])),
... (batch_weights, batch_expectations, batch_mapper),
... fn_output_signature=tf.float32)
>>> grad_manual = tf.reduce_sum(batch_jacobian, -1)
To perform the same gradient calculation automatically:
Expand Down Expand Up @@ -295,13 +298,21 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
`new_symbol_names`. Thus, at each index `i` in the first
dimension is the 2-D tensor of parameter values to fill in to
`batch_programs[i]`.
batch_mapper: 3-D `tf.Tensor` of DType `tf.float32` which defines
batch_weights: 3-D `tf.Tensor` of DType `tf.float32` which defines
how much weight to give to each program when computing the
derivatives. First dimension is the length of the input
`programs`, second dimension is the length of the input
`symbol_names`, and the third dimension is determined by the
inheriting differentiator.
batch_mapper: 3-D `tf.Tensor` of DType `tf.int32` which defines
how to map expectation values of the circuits generated by this
differentiator to the derivatives of the original circuits.
It says which indices of the returned programs are relevant for
the derivative of each symbol, for use by `tf.gather`.
The first dimension is the length of the input `programs`, the
second dimension is the length of the input `symbol_names`,
and the third dimension is the length of the second dimension of
the output `batch_programs`.
and the third dimension is the length of the last dimension of
the output `batch_weights`.
"""

@abc.abstractmethod
Expand Down
55 changes: 49 additions & 6 deletions tensorflow_quantum/python/differentiators/parameter_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,55 @@ class ParameterShift(differentiator.Differentiator):
"""

@tf.function
def get_gradient_circuits(self, programs, symbol_names, symbol_values):
"""See base class description."""
raise NotImplementedError(
"Gradient circuits are not currently available for "
"ParameterShift.")
# these get used a lot
n_symbols = tf.gather(tf.shape(symbol_names), 0)
n_programs = tf.gather(tf.shape(programs), 0)

# Assume cirq.decompose() generates gates with at most two distinct
# eigenvalues, which results in two parameter shifts.
n_shifts = 2

# These new_programs are parameter shifted.
# shapes: [n_symbols, n_programs, n_param_gates, n_shifts]
(new_programs, weights, shifts,
n_param_gates) = parameter_shift_util.parse_programs(
programs, symbol_names, symbol_values, n_symbols)

m_tile = n_shifts * n_param_gates * n_symbols

# Transpose to correct shape,
# [n_programs, n_symbols, n_param_gates, n_shifts],
# then reshape to the correct batch size
batch_programs = tf.reshape(tf.transpose(new_programs, [1, 0, 2, 3]),
[n_programs, m_tile])
batch_weights = tf.reshape(
tf.transpose(weights, [1, 0, 2, 3]),
[n_programs, n_symbols, n_param_gates * n_shifts])
shifts = tf.reshape(tf.transpose(shifts, [1, 0, 2, 3]),
[n_programs, m_tile, 1])

# Append impurity symbol into symbol name
new_symbol_names = tf.concat([
symbol_names,
tf.constant([parameter_shift_util.PARAMETER_IMPURITY_NAME])
], 0)

# Symbol values are the input symbol values, tiled according to
# `batch_programs`, with the shift values appended.
tiled_symbol_values = tf.tile(tf.expand_dims(symbol_values, 1),
[1, m_tile, 1])
batch_symbol_values = tf.concat([tiled_symbol_values, shifts], 2)

single_program_mapper = tf.reshape(
tf.range(n_symbols * n_param_gates * n_shifts),
[n_symbols, n_param_gates * n_shifts])
batch_mapper = tf.tile(tf.expand_dims(single_program_mapper, 0),
[n_programs, 1, 1])

return (batch_programs, new_symbol_names, batch_symbol_values,
batch_weights, batch_mapper)

@differentiator.catch_empty_inputs
@tf.function
Expand Down Expand Up @@ -158,7 +201,7 @@ def differentiate_analytic(self, programs, symbol_names, symbol_values,
new_symbol_names = tf.concat([
symbol_names,
tf.expand_dims(tf.constant(
parameter_shift_util._PARAMETER_IMPURITY_NAME),
parameter_shift_util.PARAMETER_IMPURITY_NAME),
axis=0)
],
axis=0)
Expand Down Expand Up @@ -304,7 +347,7 @@ def differentiate_sampled(self, programs, symbol_names, symbol_values,
new_symbol_names = tf.concat([
symbol_names,
tf.expand_dims(tf.constant(
parameter_shift_util._PARAMETER_IMPURITY_NAME),
parameter_shift_util.PARAMETER_IMPURITY_NAME),
axis=0)
],
axis=0)
Expand Down
188 changes: 181 additions & 7 deletions tensorflow_quantum/python/differentiators/parameter_shift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ def _simple_op_inputs():
class ParameterShiftTest(tf.test.TestCase, parameterized.TestCase):
"""Test the ParameterShift Differentiator will run end to end."""

def test_no_gradient_circuits(self):
"""Confirm ParameterShift differentiator has no gradient circuits."""
dif = parameter_shift.ParameterShift()
with self.assertRaisesRegex(NotImplementedError,
expected_regex="not currently available"):
_ = dif.get_gradient_circuits(None, None, None)

def test_parameter_shift_analytic(self):
"""Test if ParameterShift.differentiate_analytical doesn't crash before
running."""
Expand Down Expand Up @@ -86,6 +79,187 @@ def test_parameter_shift_sampled(self):
self.assertAllClose(expectations, true_f, atol=1e-1, rtol=1e-1)
self.assertAllClose(grads, true_g, atol=1e-1, rtol=1e-1)

def test_get_gradient_circuits(self):
"""Test that the correct objects are returned."""

diff = parameter_shift.ParameterShift()

# Circuits to differentiate.
symbols = [sympy.Symbol("s0"), sympy.Symbol("s1")]
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(1, 2)
input_programs = util.convert_to_tensor([
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(cirq.Y(q1)**symbols[1]),
])
input_symbol_names = tf.constant([str(s) for s in symbols])
input_symbol_values = tf.constant([[1.5, -2.7], [-0.3, 0.9]])

# First, for each symbol `s`, check how many times `s` appears in each
# program `p`, `n_ps`. Let `n_param_gates` be the maximum of `n_ps` over
# all symbols and programs. Then, the shape of `batch_programs` will be
# [n_programs, n_symbols * n_param_gates * n_shifts], where `n_shifts`
# is 2 because we decompose into gates with 2 eigenvalues. For row index
# `p` we have for column indices between `i * n_param_gates * n_shifts`
# and `(i + 1) * n_param_gates * n_shifts`, the first `n_pi * 2`
# programs are parameter shifted versions of `input_programs[p]` and the
# remaining programs are empty.
# Here, `n_param_gates` is 2.
impurity_symbol_name = "_impurity_for_param_shift"
impurity_symbol = sympy.Symbol(impurity_symbol_name)
expected_batch_programs_0 = util.convert_to_tensor([
cirq.Circuit(
cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**impurity_symbol,
cirq.Y(q0)**symbols[0],
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**impurity_symbol,
cirq.ry(symbols[1])(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(
cirq.X(q0)**symbols[0],
cirq.Y(q0)**symbols[0],
cirq.ry(impurity_symbol)(q1)),
cirq.Circuit(),
cirq.Circuit()
])
expected_batch_programs_1 = util.convert_to_tensor([
cirq.Circuit(),
cirq.Circuit(),
cirq.Circuit(),
cirq.Circuit(),
cirq.Circuit(cirq.Y(q1)**impurity_symbol),
cirq.Circuit(cirq.Y(q1)**impurity_symbol),
cirq.Circuit(),
cirq.Circuit()
])
expected_batch_programs = tf.stack(
[expected_batch_programs_0, expected_batch_programs_1])

# The new symbols are the old ones, with an extra used for shifting.
expected_new_symbol_names = tf.concat(
[input_symbol_names,
tf.constant([impurity_symbol_name])], 0)

# The batch symbol values are the input symbol values, tiled and with
# shifted values appended. Locations that have empty programs should
# also have zero for the shift.
# The shifted values are the original value plus 1/2 divided by the
# `exponent_scalar` of the gate.
expected_batch_symbol_values = tf.constant(
[[[1.5, -2.7, 1.5 + 0.5], [1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, 1.5 + 0.5], [1.5, -2.7, 1.5 - 0.5],
[1.5, -2.7, -2.7 + np.pi / 2], [1.5, -2.7, -2.7 - np.pi / 2],
[1.5, -2.7, -2.7], [1.5, -2.7, -2.7]],
[[-0.3, 0.9, -0.3], [-0.3, 0.9, -0.3], [-0.3, 0.9, -0.3],
[-0.3, 0.9, -0.3], [-0.3, 0.9, 0.9 + 0.5], [-0.3, 0.9, 0.9 - 0.5],
[-0.3, 0.9, 0.9], [-0.3, 0.9, 0.9]]])

# Empty program locations are given zero weight.
expected_batch_weights = tf.constant(
[[[np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2],
[0.5, -0.5, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0], [np.pi / 2, -np.pi / 2, 0.0, 0.0]]])

expected_batch_mapper = tf.constant([[[0, 1, 2, 3], [4, 5, 6, 7]],
[[0, 1, 2, 3], [4, 5, 6, 7]]])

(test_batch_programs, test_new_symbol_names, test_batch_symbol_values,
test_batch_weights, test_batch_mapper) = diff.get_gradient_circuits(
input_programs, input_symbol_names, input_symbol_values)
for i in range(tf.shape(input_programs)[0]):
self.assertAllEqual(util.from_tensor(expected_batch_programs[i]),
util.from_tensor(test_batch_programs[i]))
self.assertAllEqual(expected_new_symbol_names, test_new_symbol_names)
self.assertAllClose(expected_batch_symbol_values,
test_batch_symbol_values,
atol=1e-6)
self.assertAllClose(expected_batch_weights,
test_batch_weights,
atol=1e-6)
self.assertAllEqual(expected_batch_mapper, test_batch_mapper)

@parameterized.parameters(
list(
util.kwargs_cartesian_product(
**{
'differentiator': [parameter_shift.ParameterShift(),],
'n_qubits': [5],
'n_programs': [3],
'n_ops': [3],
'symbol_names': [['a', 'b']]
})))
def test_gradient_circuits_grad_comparison(self, differentiator, n_qubits,
n_programs, n_ops, symbol_names):
"""Test that analytic gradient agrees with the one from grad circuits"""
# Get random circuits to check.
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, resolver_batch = \
util.random_symbol_circuit_resolver_batch(
cirq.GridQubit.rect(1, n_qubits), symbol_names, n_programs)
psums = [
util.random_pauli_sums(qubits, 1, n_ops) for _ in circuit_batch
]

# Convert to tensors.
symbol_names_array = np.array(symbol_names)
symbol_values_array = np.array(
[[resolver[symbol]
for symbol in symbol_names]
for resolver in resolver_batch],
dtype=np.float32)
symbol_names_tensor = tf.convert_to_tensor(symbol_names_array)
symbol_values_tensor = tf.convert_to_tensor(symbol_values_array)
programs = util.convert_to_tensor(circuit_batch)
ops_tensor = util.convert_to_tensor(psums)

# Get gradients using expectations of gradient circuits.
(batch_programs, new_symbol_names, batch_symbol_values, batch_weights,
batch_mapper) = differentiator.get_gradient_circuits(
programs, symbol_names_tensor, symbol_values_tensor)
analytic_op = circuit_execution_ops.get_expectation_op()
batch_pauli_sums = tf.tile(tf.expand_dims(ops_tensor, 1),
[1, tf.shape(batch_programs)[1], 1])
n_batch_programs = tf.reduce_prod(tf.shape(batch_programs))
n_symbols = tf.shape(new_symbol_names)[0]
batch_expectations = analytic_op(
tf.reshape(batch_programs, [n_batch_programs]), new_symbol_names,
tf.reshape(batch_symbol_values, [n_batch_programs, n_symbols]),
tf.reshape(batch_pauli_sums, [n_batch_programs, n_ops]))
batch_expectations = tf.reshape(batch_expectations,
tf.shape(batch_pauli_sums))
batch_jacobian = tf.map_fn(
lambda x: tf.einsum('km,kmp->kp', x[0], tf.gather(x[1], x[2])),
(batch_weights, batch_expectations, batch_mapper),
fn_output_signature=tf.float32)
grad_manual = tf.reduce_sum(batch_jacobian, -1)

# Get gradients using autodiff.
differentiator.refresh()
differentiable_op = differentiator.generate_differentiable_op(
analytic_op=analytic_op)
with tf.GradientTape() as g:
g.watch(symbol_values_tensor)
exact_outputs = differentiable_op(programs, symbol_names_tensor,
symbol_values_tensor, ops_tensor)
grad_auto = g.gradient(exact_outputs, symbol_values_tensor)
self.assertAllClose(grad_manual, grad_auto)


if __name__ == "__main__":
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from tensorflow_quantum.core.ops import tfq_ps_util_ops

_PARAMETER_IMPURITY_NAME = '_param_shift'
PARAMETER_IMPURITY_NAME = '_impurity_for_param_shift'


@tf.function
Expand Down Expand Up @@ -65,7 +65,7 @@ def parse_programs(programs, symbol_names, symbol_values, n_symbols,

# Collecting doped programs with impurity sympy.Symbol from all programs
# with parameterized gates.
impurity = tf.tile(tf.convert_to_tensor([_PARAMETER_IMPURITY_NAME]),
impurity = tf.tile(tf.convert_to_tensor([PARAMETER_IMPURITY_NAME]),
[n_symbols])
symbols = tf.convert_to_tensor(symbol_names)

Expand All @@ -78,6 +78,7 @@ def parse_programs(programs, symbol_names, symbol_values, n_symbols,
n_param_gates = tf.cast(tf.gather(tf.shape(new_programs), 2),
dtype=tf.int32)

# This is a tensor of the `exponent_scalar`s of the shifted gates.
coeff = tf.expand_dims(tf.transpose(
tfq_ps_util_ops.tfq_ps_weights_from_symbols(decomposed_programs,
symbols), [1, 0, 2]),
Expand Down

0 comments on commit 9c8dccc

Please sign in to comment.