Skip to content

Commit

Permalink
Modified layers to use .shape instead of len when dealing with Tensor…
Browse files Browse the repository at this point in the history
…s. (tensorflow#168)

Co-authored-by: Michael Broughton <[email protected]>
  • Loading branch information
MichaelBroughton and MichaelBroughton authored Mar 21, 2020
1 parent 260ad7e commit 3461d74
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def call(self,
if self._w is None:
# don't re-add variable.
self._w = self.add_weight(name='circuit_learnable_parameters',
shape=[len(symbol_names)],
shape=symbol_names.shape,
initializer=initializer)

symbol_values = tf.tile(tf.expand_dims(self._w, axis=0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def call(self,
if self._w is None:
# don't re-add variable.
self._w = self.add_weight(name='circuit_learnable_parameters',
shape=[len(symbol_names)],
shape=symbol_names.shape,
initializer=initializer)

symbol_values = tf.tile(tf.expand_dims(self._w, axis=0),
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_quantum/python/layers/high_level/controlled_pqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ def __init__(self,
raise TypeError("model_circuit must be a cirq.Circuit object."
" Given: ".format(model_circuit))

self._symbols = tf.constant(
list(sorted(util.get_circuit_symbols(model_circuit))))
self._symbols_list = list(
sorted(util.get_circuit_symbols(model_circuit)))
self._symbols = tf.constant([str(x) for x in self._symbols_list])
self._circuit = util.convert_to_tensor([model_circuit])

if len(self._symbols) == 0:
if len(self._symbols_list) == 0:
raise ValueError("model_circuit has no sympy.Symbols. Please "
"provide a circuit that contains symbols so "
"that their values can be trained.")
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_quantum/python/layers/high_level/pqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,11 @@ def __init__(
if not isinstance(model_circuit, cirq.Circuit):
raise TypeError("model_circuit must be a cirq.Circuit object."
" Given: {}".format(model_circuit))
self._symbols = tf.constant(
list(sorted(util.get_circuit_symbols(model_circuit))))
self._symbols_list = list(
sorted(util.get_circuit_symbols(model_circuit)))
self._symbols = tf.constant([str(x) for x in self._symbols_list])
self._model_circuit = util.convert_to_tensor([model_circuit])
if len(self._symbols) == 0:
if len(self._symbols_list) == 0:
raise ValueError("model_circuit has no sympy.Symbols. Please "
"provide a circuit that contains symbols so "
"that their values can be trained.")
Expand Down Expand Up @@ -250,7 +251,7 @@ def __init__(
# Weight creation is not placed in a Build function because the number
# of weights is independent of the input shape.
self.parameters = self.add_weight('parameters',
shape=[len(self._symbols)],
shape=self._symbols.shape,
initializer=self.initializer,
regularizer=self.regularizer,
constraint=self.constraint,
Expand Down

0 comments on commit 3461d74

Please sign in to comment.