Skip to content

Commit

Permalink
Merge pull request tensorflow#544 from tensorflow/n_sample_layer
Browse files Browse the repository at this point in the history
Added noisy backend to sample layer.
  • Loading branch information
jaeyoo authored Apr 16, 2021
2 parents 4aff0df + c3592fd commit ca4d0cb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
3 changes: 3 additions & 0 deletions tensorflow_quantum/core/ops/noise/tfq_noisy_samples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ class TfqNoisySamplesOp : public tensorflow::OpKernel {
int nq = num_qubits[i];
int j = start > 0 ? offset_prefix_sum[start - 1][i] : 0;
int needed_samples = offset_prefix_sum[start][i] - j;
if (needed_samples <= 0) {
continue;
}

if (nq > largest_nq) {
largest_nq = nq;
Expand Down
1 change: 1 addition & 0 deletions tensorflow_quantum/python/layers/circuit_executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ py_library(
deps = [
":input_checks",
"//tensorflow_quantum/core/ops:circuit_execution_ops",
"//tensorflow_quantum/core/ops/noise:noisy_samples_op_py",
],
)

Expand Down
19 changes: 14 additions & 5 deletions tensorflow_quantum/python/layers/circuit_executors/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tensorflow as tf

from tensorflow_quantum.core.ops import circuit_execution_ops
from tensorflow_quantum.core.ops.noise import noisy_samples_op
from tensorflow_quantum.python.layers.circuit_executors import input_checks


Expand Down Expand Up @@ -138,20 +139,28 @@ class Sample(tf.keras.layers.Layer):
"""

def __init__(self, backend=None, **kwargs):
def __init__(self, backend='noiseless', **kwargs):
"""Instantiate this Layer.
Create a layer that will output bitstring samples taken from either a
simulated quantum state or a real quantum computer
Args:
backend: Optional Backend to use to simulate this state. Defaults
to the native Tensorflow simulator (None), however users may
also specify a preconfigured cirq execution object to use
instead, which must inherit `cirq.Sampler`.
to the noiseless simulator. Options are {'noisy', 'noiseless'},
however users may also specify a preconfigured cirq execution
object to use instead, which must inherit `cirq.Sampler`.
"""
super().__init__(**kwargs)
self.sample_op = circuit_execution_ops.get_sampling_op(backend)
used_op = None
if backend == 'noiseless':
used_op = circuit_execution_ops.get_sampling_op(None)
elif backend == 'noisy':
used_op = noisy_samples_op.samples
else:
used_op = circuit_execution_ops.get_sampling_op(backend)

self.sample_op = used_op

def call(self,
inputs,
Expand Down
31 changes: 19 additions & 12 deletions tensorflow_quantum/python/layers/circuit_executors/sample_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,20 @@ def test_sample_invalid_shape_inputs(self):
TypeError, expected_regex="cannot be parsed to int32 tensor"):
sampler([cirq.Circuit()], repetitions=[10])

@parameterized.parameters([{
'backend': None
}, {
'backend': cirq.Simulator()
}, {
'backend': cirq.DensityMatrixSimulator()
}])
@parameterized.parameters([
{
'backend': 'noiseless'
},
{
'backend': 'noisy'
},
{
'backend': cirq.Simulator()
},
{
'backend': None # old API usage.
}
])
def test_sample_invalid_combinations(self, backend):
"""Test with valid type inputs and valid value, but incorrect combo."""
sampler = sample.Sample(backend)
Expand Down Expand Up @@ -152,11 +159,10 @@ def test_sample_outputs_simple(self):
@parameterized.parameters(
list(
util.kwargs_cartesian_product(
backend=[None,
cirq.Simulator(),
cirq.DensityMatrixSimulator()],
all_n_qubits=[[3], [8], [3, 4], [3, 4, 10]],
n_samples=[1, 10, 100],
backend=['noiseless', 'noisy',
cirq.Simulator(), None],
all_n_qubits=[[3, 4, 10]],
n_samples=[1],
symbol_names=[[], ['a', 'b']])))
def test_sample_output(self, backend, all_n_qubits, n_samples,
symbol_names):
Expand All @@ -178,6 +184,7 @@ def test_sample_output(self, backend, all_n_qubits, n_samples,
symbol_names=symbol_names,
symbol_values=symbol_values,
repetitions=n_samples).to_list()

self.assertEqual(expected_outputs, layer_output)


Expand Down

0 comments on commit ca4d0cb

Please sign in to comment.