Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio Martinez committed Apr 2, 2021
1 parent df8e26c commit 4fdd6cd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 4 additions & 2 deletions tensorflow_quantum/python/differentiators/parameter_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
batch_programs = tf.reshape(
tf.transpose(new_programs, [1, 0, 2, 3]), [n_programs, m_tile])
weights = tf.reshape(
tf.transpose(weights, [1, 0, 2, 3]), [n_programs, m_tile])
tf.transpose(weights, [1, 0, 2, 3]), [n_programs, n_symbols, n_param_gates * n_shifts])
shifts = tf.reshape(
tf.transpose(shifts, [1, 0, 2, 3]), [n_programs, m_tile, 1])

Expand All @@ -95,7 +95,9 @@ def get_gradient_circuits(self, programs, symbol_names, symbol_values):
tf.expand_dims(symbol_values, 1), [1, m_tile, 1])
batch_symbol_values = tf.concat([tiled_symbol_values, shifts], 2)

return (batch_programs, new_symbol_names, batch_symbol_values, None)
batch_mapper = tf.tile(tiled_expectation)

return (batch_programs, new_symbol_names, batch_symbol_values, batch_mapper)

@differentiator.catch_empty_inputs
@tf.function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,15 @@ def test_get_gradient_circuits(self):
[-0.3, 0.9, 0.9],
[-0.3, 0.9, 0.9]]])

#


(test_batch_programs, test_new_symbol_names, test_batch_symbol_values,
test_batch_mapper) = diff.get_gradient_circuits(
input_programs, input_symbol_names, input_symbol_values)
for i in range(tf.shape(input_programs)[0]):
self.assertAllEqual(util.from_tensor(expected_batch_programs[i]),
util.from_tensor(test_batch_programs[i]))
for i in range(tf.shape(expected_batch_symbol_values)[0]):
for j in range(tf.shape(expected_batch_symbol_values[i])[0]):
print(f"{i,j}:")
print(expected_batch_symbol_values[i][j])
print(test_batch_symbol_values[i][j])
print()
self.assertAllEqual(expected_new_symbol_names, test_new_symbol_names)
self.assertAllClose(expected_batch_symbol_values,
test_batch_symbol_values,
Expand Down

0 comments on commit 4fdd6cd

Please sign in to comment.