Skip to content

Commit

Permalink
Upgrade sympy to 1.5 (tensorflow#262)
Browse files Browse the repository at this point in the history
* Upgrade sympy to 1.5

* fix lint in serializer_test.py
  • Loading branch information
MichaelBroughton authored Jun 15, 2020
1 parent 19cd38e commit 4b7f6ae
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 48 deletions.
2 changes: 1 addition & 1 deletion release/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def finalize_options(self):
self.install_lib = self.install_platlib


REQUIRED_PACKAGES = ['cirq == 0.8.0', 'pathos == 0.2.5', 'sympy == 1.4']
REQUIRED_PACKAGES = ['cirq == 0.8.0', 'pathos == 0.2.5', 'sympy == 1.5']
CUR_VERSION = '0.4.0'


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cirq==0.8.0
sympy==1.4
sympy==1.5
nbconvert==5.6.1
nbformat==4.4.0
pylint==2.4.4
Expand Down
41 changes: 26 additions & 15 deletions tensorflow_quantum/core/ops/tfq_ps_util_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,33 +427,44 @@ def test_weight_coefficient(self):
"""Test that scalar multiples of trivial case work."""
bit = cirq.GridQubit(0, 0)
circuit = cirq.Circuit(
cirq.X(bit)**(sympy.Symbol('alpha') * 2.0),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.0),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.0),
cirq.X(bit)**(sympy.Symbol('alpha') * 2.4),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.4),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.4),
)
inputs = util.convert_to_tensor([circuit])
symbols = tf.convert_to_tensor(['alpha'])
new = tf.convert_to_tensor(['new'])
res = tfq_ps_util_ops.tfq_ps_symbol_replace(inputs, symbols, new)
output = util.from_tensor(res)
correct_00 = cirq.Circuit(
cirq.X(bit)**(sympy.Symbol('new') * 2.0),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.0),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.0),
cirq.X(bit)**(sympy.Symbol('new') * 2.4),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.4),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.4),
)
correct_01 = cirq.Circuit(
cirq.X(bit)**(sympy.Symbol('alpha') * 2.0),
cirq.Y(bit)**(sympy.Symbol('new') * 3.0),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.0),
cirq.X(bit)**(sympy.Symbol('alpha') * 2.4),
cirq.Y(bit)**(sympy.Symbol('new') * 3.4),
cirq.Z(bit)**(sympy.Symbol('alpha') * 4.4),
)
correct_02 = cirq.Circuit(
cirq.X(bit)**(sympy.Symbol('alpha') * 2.0),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.0),
cirq.Z(bit)**(sympy.Symbol('new') * 4.0),
cirq.X(bit)**(sympy.Symbol('alpha') * 2.4),
cirq.Y(bit)**(sympy.Symbol('alpha') * 3.4),
cirq.Z(bit)**(sympy.Symbol('new') * 4.4),
)
self.assertEqual(correct_00, output[0][0][0])
self.assertEqual(correct_01, output[0][0][1])
self.assertEqual(correct_02, output[0][0][2])
for i, c in enumerate([correct_00, correct_01, correct_02]):
u1 = cirq.unitary(
cirq.resolve_parameters(c,
param_resolver={
'alpha': 1.23,
'new': 4.56
}))
u2 = cirq.unitary(
cirq.resolve_parameters(output[0][0][i],
param_resolver={
'alpha': 1.23,
'new': 4.56
}))
self.assertTrue(cirq.approx_eq(u1, u2, atol=1e-5))

def test_simple_pad(self):
"""Test simple padding."""
Expand Down
62 changes: 31 additions & 31 deletions tensorflow_quantum/core/serialize/serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def _get_valid_circuit_proto_pairs():
_build_gate_proto("HP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0'])),
(cirq.Circuit(cirq.HPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0)),
(cirq.Circuit(cirq.HPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0)),
_build_gate_proto("HP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0'])),
['alpha', 3.1, 0.0], ['0_0'])),
(cirq.Circuit(cirq.H(q0)),
_build_gate_proto("HP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -119,10 +119,10 @@ def _get_valid_circuit_proto_pairs():
_build_gate_proto("XP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0'])),
(cirq.Circuit(cirq.XPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0)),
(cirq.Circuit(cirq.XPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0)),
_build_gate_proto("XP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0'])),
['alpha', 3.1, 0.0], ['0_0'])),
(cirq.Circuit(cirq.X(q0)),
_build_gate_proto("XP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -137,10 +137,10 @@ def _get_valid_circuit_proto_pairs():
_build_gate_proto("YP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0'])),
(cirq.Circuit(cirq.YPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0)),
(cirq.Circuit(cirq.YPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0)),
_build_gate_proto("YP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0'])),
['alpha', 3.1, 0.0], ['0_0'])),
(cirq.Circuit(cirq.Y(q0)),
_build_gate_proto("YP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -155,10 +155,10 @@ def _get_valid_circuit_proto_pairs():
_build_gate_proto("ZP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0'])),
(cirq.Circuit(cirq.ZPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0)),
(cirq.Circuit(cirq.ZPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0)),
_build_gate_proto("ZP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0'])),
['alpha', 3.1, 0.0], ['0_0'])),
(cirq.Circuit(cirq.Z(q0)),
_build_gate_proto("ZP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -174,10 +174,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.XXPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.XXPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("XXP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.XX(q0, q1)),
_build_gate_proto("XXP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -193,10 +193,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.YYPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.YYPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("YYP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.YY(q0, q1)),
_build_gate_proto("YYP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -212,10 +212,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.ZZPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.ZZPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("ZZP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.ZZ(q0, q1)),
_build_gate_proto("ZZP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -231,10 +231,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.CZPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.CZPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("CZP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.CZ(q0, q1)),
_build_gate_proto("CZP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -250,10 +250,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.CNotPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.CNotPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("CNP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.CNOT(q0, q1)),
_build_gate_proto("CNP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -269,10 +269,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.SwapPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.SwapPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("SP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.SWAP(q0, q1)),
_build_gate_proto("SP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -289,10 +289,10 @@ def _get_valid_circuit_proto_pairs():
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 1.0, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.ISwapPowGate(exponent=3.0 * sympy.Symbol('alpha'))(q0, q1)),
cirq.ISwapPowGate(exponent=3.1 * sympy.Symbol('alpha'))(q0, q1)),
_build_gate_proto("ISP",
['exponent', 'exponent_scalar', 'global_shift'],
['alpha', 3.0, 0.0], ['0_0', '0_1'])),
['alpha', 3.1, 0.0], ['0_0', '0_1'])),
(cirq.Circuit(cirq.ISWAP(q0, q1)),
_build_gate_proto("ISP",
['exponent', 'exponent_scalar', 'global_shift'],
Expand All @@ -315,12 +315,12 @@ def _get_valid_circuit_proto_pairs():
'exponent_scalar', 'global_shift'
], ['alpha', 1.0, 0.3, 1.0, 0.0], ['0_0'])),
(cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=3.0 * sympy.Symbol('alpha'),
cirq.PhasedXPowGate(phase_exponent=3.1 * sympy.Symbol('alpha'),
exponent=0.3)(q0)),
_build_gate_proto("PXP", [
'phase_exponent', 'phase_exponent_scalar', 'exponent',
'exponent_scalar', 'global_shift'
], ['alpha', 3.0, 0.3, 1.0, 0.0], ['0_0'])),
], ['alpha', 3.1, 0.3, 1.0, 0.0], ['0_0'])),
(cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=0.9,
exponent=sympy.Symbol('beta'))(q0)),
Expand All @@ -330,18 +330,18 @@ def _get_valid_circuit_proto_pairs():
], [0.9, 1.0, 'beta', 1.0, 0.0], ['0_0'])),
(cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=0.9,
exponent=5.0 * sympy.Symbol('beta'))(q0)),
exponent=5.1 * sympy.Symbol('beta'))(q0)),
_build_gate_proto("PXP", [
'phase_exponent', 'phase_exponent_scalar', 'exponent',
'exponent_scalar', 'global_shift'
], [0.9, 1.0, 'beta', 5.0, 0.0], ['0_0'])),
], [0.9, 1.0, 'beta', 5.1, 0.0], ['0_0'])),
(cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=3.0 * sympy.Symbol('alpha'),
exponent=5.0 * sympy.Symbol('beta'))(q0)),
cirq.PhasedXPowGate(phase_exponent=3.1 * sympy.Symbol('alpha'),
exponent=5.1 * sympy.Symbol('beta'))(q0)),
_build_gate_proto("PXP", [
'phase_exponent', 'phase_exponent_scalar', 'exponent',
'exponent_scalar', 'global_shift'
], ['alpha', 3.0, 'beta', 5.0, 0.0], ['0_0'])),
], ['alpha', 3.1, 'beta', 5.1, 0.0], ['0_0'])),

# RX, RY, RZ with symbolization is tested in special cases as the
# string comparison of the float converted sympy.pi does not happen
Expand Down Expand Up @@ -369,11 +369,11 @@ def _get_valid_circuit_proto_pairs():
['theta', 'theta_scalar', 'phi', 'phi_scalar'],
[0.1, 1.0, 0.2, 1.0], ['0_0', '0_1'])),
(cirq.Circuit(
cirq.FSimGate(theta=2.0 * sympy.Symbol("alpha"),
cirq.FSimGate(theta=2.1 * sympy.Symbol("alpha"),
phi=1.3 * sympy.Symbol("beta"))(q0, q1)),
_build_gate_proto("FSIM",
['theta', 'theta_scalar', 'phi', 'phi_scalar'],
['alpha', 2.0, 'beta', 1.3], ['0_0', '0_1'])),
['alpha', 2.1, 'beta', 1.3], ['0_0', '0_1'])),
]

return pairs
Expand Down

0 comments on commit 4b7f6ae

Please sign in to comment.