Skip to content

Commit

Permalink
Fix cirq.unitary(np.exp(-1j * op)) (tensorflow#229)
Browse files Browse the repository at this point in the history
* Add _exponential() and remove cirq.unitary(np.exp())

* Fix format and lint
  • Loading branch information
jaeyoo authored May 5, 2020
1 parent 0202da7 commit 7dc3140
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions tensorflow_quantum/python/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _single_to_tensor(item):
return serializer.serialize_circuit(item).SerializeToString()


def _exponential(theta, op):
op_mat = cirq.unitary(op)
return np.eye(op_mat.shape[0]) * np.cos(theta) - 1j * op_mat * np.sin(theta)


BITS = list(cirq.GridQubit.rect(1, 10))


Expand Down Expand Up @@ -262,10 +267,7 @@ def test_exponential_simple(self):
for op in [cirq.X, cirq.Y, cirq.Z]:
theta = np.random.random()
circuit = util.exponential(operators=[theta * op(q)])

# TODO(jaeyoo) : remove factor 2 if cirq issue is resolved
# https://github.com/quantumlib/Cirq/issues/2710
ground_truth_unitary = cirq.unitary(np.exp(-1j * 2 * theta * op(q)))
ground_truth_unitary = _exponential(theta, op(q))
self.assertAllClose(ground_truth_unitary, cirq.unitary(circuit))

def test_exponential_identity(self):
Expand Down Expand Up @@ -298,9 +300,9 @@ def test_exponential_complex(self):
theta1 = np.random.random()
theta2 = np.random.random()
identity = cirq.PauliString({None: cirq.I})
op1 = theta1 * cirq.Z(q[1]) * cirq.Z(q[2])
op2 = theta2 * identity
circuit = util.exponential(operators=[op1, op2])
op1 = cirq.Z(q[1]) * cirq.Z(q[2])
op2 = identity
circuit = util.exponential(operators=[theta1 * op1, theta2 * op2])

result_gates = []
for moment in circuit:
Expand All @@ -319,10 +321,7 @@ def test_exponential_complex(self):
for i in range(3, 7):
self.assertEqual(result_gates[i].qubits, (q[1],))

# TODO(jaeyoo) : remove factor 2 if cirq issue is resolved
# https://github.com/quantumlib/Cirq/issues/2710
ground_truth_unitary = cirq.unitary(np.exp(-1j * 2 * op1))
ground_truth_unitary *= cirq.unitary(np.exp(-1j * op2))
ground_truth_unitary = _exponential(theta1, op1)
result_unitary = cirq.unitary(circuit)
global_phase = ground_truth_unitary[0][0] / result_unitary[0][0]
result_unitary *= global_phase
Expand Down

0 comments on commit 7dc3140

Please sign in to comment.