Ensure gradient of tf.math.fidelity remains float32 when autographed. (
MichaelBroughton authored Jun 18, 2021
1 parent 534f65d commit 68e0ec3
Showing 3 changed files with 32 additions and 6 deletions.
16 changes: 13 additions & 3 deletions tensorflow_quantum/core/ops/math_ops/fidelity_op.py
@@ -18,6 +18,7 @@
 
 
 @tf.function
+@tf.custom_gradient
 def fidelity(programs, symbol_names, symbol_values, other_programs):
     """Calculate the fidelity between circuits.
@@ -78,7 +79,16 @@ def fidelity(programs, symbol_names, symbol_values, other_programs):
             to the fidelity of `programs[i]` with `symbol_values[i]`
             resolved in and `other_programs[i][j]`.
     """
-    ip = inner_product_op.inner_product(programs, symbol_names,
-                                        tf.cast(symbol_values, tf.float32),
+    f32_vals = tf.cast(symbol_values, tf.float32)
+    ip = inner_product_op.inner_product(programs, symbol_names, f32_vals,
                                         other_programs)
-    return tf.math.abs(ip)**2
+
+    def grad(dy):
+        ret_zero = tf.equal(tf.size(symbol_names), 0)
+        inner_prod_grad = tf.cond(
+            ret_zero, lambda: tf.zeros_like(symbol_values, dtype=tf.float32),
+            lambda: tf.math.real(2. * ip * inner_product_op._inner_product_grad(
+                programs, symbol_names, symbol_values, other_programs, dy)))
+        return [None, None, inner_prod_grad, None]
+
+    return tf.math.abs(ip)**2, grad
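
For reference, the identity the grad closure implements: with z(θ) denoting the inner product ip, the fidelity is |z(θ)|² and

$$\frac{\partial}{\partial\theta}\,\bigl|z(\theta)\bigr|^{2} \;=\; 2\,\operatorname{Re}\!\Bigl(\overline{z(\theta)}\,\frac{\partial z(\theta)}{\partial\theta}\Bigr),$$

which the code expresses as tf.math.real(2. * ip * _inner_product_grad(...)); the complex conjugation is presumably folded into the C++ gradient kernel, so only the real part is taken here. When symbol_names is empty there is nothing to differentiate, and the tf.cond guard short-circuits to an all-zeros float32 tensor shaped like symbol_values.
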
7 changes: 7 additions & 0 deletions tensorflow_quantum/core/ops/math_ops/fidelity_op_test.py
@@ -86,6 +86,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
                 out_arr[i][j] = np.abs(np.vdot(final_wf, internal_wf))**2
 
         self.assertAllClose(out, out_arr, atol=1e-5)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
     @parameterized.parameters([
         {
@@ -138,6 +139,7 @@ def test_correctness_without_symbols(self, n_qubits, batch_size,
                 out_arr[i][j] = np.abs(np.vdot(final_wf, internal_wf))**2
 
         self.assertAllClose(out, out_arr, atol=1e-5)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
     def test_correctness_empty(self):
         """Tests the fidelity with empty circuits."""
@@ -151,6 +153,7 @@ def test_correctness_empty(self):
                                   other_program)
         expected = np.array([[1.0]], dtype=np.complex64)
         self.assertAllClose(out, expected)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
         qubit = cirq.GridQubit(0, 0)
         non_empty_circuit = util.convert_to_tensor(
@@ -235,6 +238,7 @@ def test_tf_gradient_correctness_with_symbols(self, n_qubits, batch_size,
                     out_arr[i][k] += grad_fid
 
         self.assertAllClose(out, out_arr, atol=1e-3)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
     @parameterized.parameters([
         {
@@ -272,6 +276,7 @@ def test_tf_gradient_correctness_without_symbols(self, n_qubits, batch_size,
                                       other_programs)
         out = tape.gradient(ip, symbol_values)
         self.assertAllClose(out, tf.zeros_like(symbol_values), atol=1e-3)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
     def test_correctness_no_circuit(self):
         """Test the inner product between no circuits."""
@@ -284,6 +289,7 @@ def test_correctness_no_circuit(self):
         out = fidelity_op.fidelity(empty_circuit, empty_symbols, empty_values,
                                    other_program)
         self.assertShapeEqual(np.zeros((0, 0)), out)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
     def test_tf_gradient_correctness_no_circuit(self):
         """Test the inner product grad between no circuits."""
@@ -299,6 +305,7 @@ def test_tf_gradient_correctness_no_circuit(self):
                                        empty_values, other_program)
 
         self.assertShapeEqual(np.zeros((0, 0)), out)
+        self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
 
 
 if __name__ == "__main__":
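
A usage sketch of what these tests exercise (assumptions: cirq, sympy, and tensorflow_quantum are installed, and fidelity is exported as tfq.math.fidelity as the commit title suggests; the variable names are illustrative, not from the commit). Even when symbol_values arrives as float64, the custom gradient keeps both the forward value and tape.gradient output in float32:

import cirq
import sympy
import tensorflow as tf
import tensorflow_quantum as tfq

q = cirq.GridQubit(0, 0)
alpha = sympy.Symbol('alpha')
programs = tfq.convert_to_tensor([cirq.Circuit(cirq.X(q)**alpha)])
other_programs = tfq.convert_to_tensor([[cirq.Circuit(cirq.H(q))]])
symbol_names = tf.constant(['alpha'])
symbol_values = tf.constant([[0.5]], dtype=tf.float64)  # deliberately float64

with tf.GradientTape() as tape:
    tape.watch(symbol_values)
    fid = tfq.math.fidelity(programs, symbol_names, symbol_values,
                            other_programs)
grad = tape.gradient(fid, symbol_values)
print(fid.dtype, grad.dtype)  # expected: float32 for both (see tests above)
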
15 changes: 12 additions & 3 deletions tensorflow_quantum/core/ops/math_ops/tfq_inner_product_grad.cc
@@ -479,9 +479,18 @@ REGISTER_OP("TfqInnerProductGrad")
           c->Dim(programs_shape, 0);
       tensorflow::shape_inference::DimensionHandle output_cols =
           c->Dim(symbol_names_shape, 0);
-      std::vector<tensorflow::shape_inference::DimensionHandle> dims = {
-          output_rows, output_cols};
-      c->set_output(0, c->MakeShape(dims));
+
+      // Use kUnknownDim instead to prevent shape inference from breaking
+      // @tf.custom_gradient code in fidelity_op.py. The grad function has
+      // an implicit data dependency on `symbol_names` that shape inference
+      // can't (and shouldn't) see. Not specifying shape prevents this break.
+      // std::vector<tensorflow::shape_inference::DimensionHandle> dims = {
+      //     output_rows,
+      //     tensorflow::shape_inference::InferenceContext::kUnknownDim};
+      c->set_output(
+          0, c->MakeShape(
+                 {output_rows,
+                  tensorflow::shape_inference::InferenceContext::kUnknownDim}));
 
       return tensorflow::Status::OK();
     });
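
To illustrate the constraint that comment describes, here is a toy, plain-TensorFlow sketch (an assumption-level analogue, not the actual op): both branches of a traced tf.cond must have compatible inferred shapes; the zero branch is shaped like symbol_values, while the gradient branch's column count depends on symbol_names at runtime. Leaving that dimension unknown — as the kUnknownDim registration does in C++ — keeps the two branches mergeable.

import tensorflow as tf

@tf.function
def toy_grad(symbol_values, symbol_names):
    # Mirrors the structure of grad() in fidelity_op.py with plain TF ops.
    ret_zero = tf.equal(tf.size(symbol_names), 0)
    return tf.cond(
        ret_zero,
        # Zero branch: statically shaped like symbol_values.
        lambda: tf.zeros_like(symbol_values),
        # "Gradient" branch: its column count is computed from symbol_names
        # at runtime, the Python-level analogue of kUnknownDim above.
        lambda: tf.ones([tf.shape(symbol_values)[0],
                         tf.size(symbol_names)],
                        dtype=symbol_values.dtype))

print(toy_grad(tf.zeros([2, 3]), tf.constant(['a', 'b', 'c'])))
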