Remove parallelization in batch sample. (tensorflow#566)
* Remove parallelization in batch sample.

* Removed dead code.

* Addressed review feedback.
MichaelBroughton authored May 11, 2021
1 parent d7dd484 commit 3036034
Showing 2 changed files with 27 additions and 166 deletions.
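
In short, the old implementation fanned sampling out over a multiprocessing pool that wrote into shared RawArray memory; the new one simply loops over the batch and calls the simulator directly. Below is a minimal standalone sketch of the serial pattern adopted here; the helper name serial_sample and the example circuits are illustrative and not part of this repository.

import cirq
import numpy as np
import sympy

def serial_sample(circuits, resolvers, n_samples, simulator):
    biggest = max(len(c.all_qubits()) for c in circuits)
    out = np.full((len(circuits), n_samples, biggest), -2, dtype=np.int8)
    for i, (c, r) in enumerate(zip(circuits, resolvers)):
        qubits = sorted(c.all_qubits())
        if not qubits:
            continue  # empty circuits keep the -2 padding
        keys = [str(j) for j in range(len(qubits))]
        measured = c + cirq.Circuit(
            cirq.measure(q, key=k) for q, k in zip(qubits, keys))
        resolved = cirq.resolve_parameters(measured, r)
        bits = simulator.sample(resolved, repetitions=n_samples)
        out[i, :, biggest - len(qubits):] = bits[keys].to_numpy().astype(np.int8)
    return out

a, b = sympy.Symbol('a'), sympy.Symbol('b')
q = cirq.GridQubit.rect(1, 2)
circuits = [cirq.Circuit(cirq.X(q[0]) ** a),
            cirq.Circuit(cirq.X(q[0]) ** a, cirq.Y(q[1]) ** b)]
resolvers = [cirq.ParamResolver({'a': 1.0}),
             cirq.ParamResolver({'a': 0.5, 'b': 0.25})]
print(serial_sample(circuits, resolvers, 10, cirq.Simulator()).shape)  # (2, 10, 2)
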
185 changes: 23 additions & 162 deletions tensorflow_quantum/core/ops/batch_util.py
@@ -14,10 +14,7 @@
# ==============================================================================
"""A module to for running Cirq objects."""
import collections
import os

import multiprocessing as mp
from multiprocessing.pool import Pool as ProcessPool
import numpy as np
import cirq

@@ -118,129 +115,6 @@ def _fixed_circuit_plus_pauli_string_measurements(circuit, pauli_string):
return circuit


def _make_complex_view(shape, init_val):
"""Build a RawArray that will map to the real and imaginary parts of a
complex number."""
shape = list(shape)
shape[-1] *= 2
data = np.ones(shape, dtype=np.float32) * init_val

flattened_size = 1
for dim_size in shape:
flattened_size *= dim_size
shared_mem_array = mp.RawArray('f', flattened_size)
np_view = np.frombuffer(shared_mem_array, dtype=np.float32).reshape(shape)
np.copyto(np_view, data)
return shared_mem_array


def _convert_complex_view_to_np(view, shape):
"""Get a numpy view ontop of the rawarray view. Small overhead."""
shape = list(shape)
shape[-1] *= 2
return np.frombuffer(view, dtype=np.float32).reshape(shape)


def _update_complex_np(np_view, i, to_add):
"""Update the shared memory undernath the numpy view.
to_add is passed by reference since we don't do much with it."""
np_view[i, ...] = np.pad(to_add,
(0, (np_view.shape[-1] // 2 - to_add.shape[-1])),
'constant',
constant_values=-2).view(np.float32)


def _convert_complex_view_to_result(view, shape):
"""Convert a rawarray view to a numpy array and reindex so that
the underlying pair of double arrays are squished together to make a
complex array of half the underlying size."""
shape = list(shape)
shape[-1] *= 2
np_view = np.frombuffer(view, dtype=np.float32).reshape(shape)

# The below view will cause a re-interpretation of underlying
# memory so use sparingly.
return np_view.view(np.complex64)
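
The removed complex-view helpers rely on reinterpreting adjacent pairs of float32 values as complex64 entries. A tiny standalone illustration of that trick (the values here are made up):

import numpy as np

interleaved = np.array([1.0, 0.0, 0.5, -0.5], dtype=np.float32)  # [re, im, re, im]
print(interleaved.view(np.complex64))  # [1. +0.j  0.5-0.5j]
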


def _make_simple_view(shape, init_val, dtype, c_code):
"""Make a shared memory view for floating type."""
data = np.ones(shape, dtype=dtype) * init_val
flattened_size = 1
for dim_size in shape:
flattened_size *= dim_size
shared_mem_array = mp.RawArray(c_code, flattened_size)
np_view = np.frombuffer(shared_mem_array, dtype=dtype).reshape(shape)
np.copyto(np_view, data)
return shared_mem_array


def _convert_simple_view_to_np(view, dtype, shape):
"""Create a numpy view to a float array, low overhead."""
return np.frombuffer(view, dtype=dtype).reshape(shape)


def _batch_update_simple_np(np_view, i, to_add):
"""Update the shared memory underneath the numpy view.
to_add is again passed by reference."""
np_view[i, ...] = to_add


def _pointwise_update_simple_np(np_view, i, j, to_add):
"""Do a batch and sub-batch index update to numpy view."""
np_view[i, j, ...] = to_add


def _convert_simple_view_to_result(view, dtype, shape):
"""Convert a RawArray view to final numpy array."""
return np.frombuffer(view, dtype=dtype).reshape(shape)


def _prep_pool_input_args(indices, *args, slice_args=True):
"""Break down a set of indices, and optional args into a generator
of length cpu_count."""
block_size = int(np.ceil(len(indices) / os.cpu_count()))
for i in range(0, len(indices), block_size):
if slice_args:
yield tuple([indices[i:i + block_size]] +
[x[i:i + block_size] for x in args])
else:
yield tuple([indices[i:i + block_size]] + [x for x in args])
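
The chunking performed by _prep_pool_input_args amounts to splitting the index range into roughly cpu_count blocks. A small illustration with made-up sizes, pretending os.cpu_count() returns 4:

import numpy as np

indices = list(range(10))
block_size = int(np.ceil(len(indices) / 4))  # pretend os.cpu_count() == 4
chunks = [indices[i:i + block_size] for i in range(0, len(indices), block_size)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
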


# Each worker process is separate from all the other processes,
# so INFO_DICTs will not step on each other.
INFO_DICT = {}


def _setup_dict(array_view, view_shape, simulator, post_process):
INFO_DICT['arr'] = array_view
INFO_DICT['shape'] = view_shape
INFO_DICT['sim'] = simulator
INFO_DICT['post_process'] = post_process


def _sample_worker_func(indices, programs, params, n_samples):
"""Sample n_samples from progams[i] with params[i] placed in it."""
x_np = _convert_simple_view_to_np(INFO_DICT['arr'], np.int32,
INFO_DICT['shape'])
simulator = INFO_DICT['sim']

for i, index in enumerate(indices):
qubits = sorted(programs[i].all_qubits())
# (#679) Just ignore empty programs.
if len(qubits) == 0:
continue
state = simulator.simulate(programs[i], params[i])
samples = INFO_DICT['post_process'](state, len(qubits),
n_samples[i]).astype(np.int32)
_batch_update_simple_np(
x_np, index,
np.pad(samples, ((0, 0), (x_np.shape[2] - len(qubits), 0)),
'constant',
constant_values=-2))


def _validate_inputs(circuits, param_resolvers, simulator, sim_type):
"""Type check and sanity check inputs."""
if not isinstance(circuits, (list, tuple, np.ndarray)):
@@ -324,24 +198,24 @@ def batch_calculate_state(circuits, param_resolvers, simulator):
return empty_ret

biggest_circuit = max(len(circuit.all_qubits()) for circuit in circuits)

# Default to state vector unless we see a DensityMatrixSimulator.
return_mem_shape = (len(circuits), 1 << biggest_circuit)
post_process = lambda x: x.final_state_vector
if isinstance(simulator, cirq.DensityMatrixSimulator):
return_mem_shape = (len(circuits), 1 << biggest_circuit,
1 << biggest_circuit)
post_process = lambda x: x.final_density_matrix
# Assumes anything else returns a state vector.
else:
return_mem_shape = (len(circuits), 1 << biggest_circuit)
post_process = lambda x: x.final_state_vector

shared_array = _make_complex_view(return_mem_shape, -2)

x_np = _convert_complex_view_to_np(shared_array, return_mem_shape)
batch_states = np.ones(return_mem_shape, dtype=np.complex64) * -2
for index, (program, param) in enumerate(zip(circuits, param_resolvers)):
result = simulator.simulate(program, param)
final_array = post_process(result).astype(np.complex64)
_update_complex_np(x_np, index, final_array)
state_size = 1 << len(program.all_qubits())
state = post_process(result).astype(np.complex64)
sub_index = (slice(None, state_size, 1),) * (batch_states.ndim - 1)
batch_states[index][sub_index] = state

return _convert_complex_view_to_result(shared_array, return_mem_shape)
return batch_states
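
The replacement code packs state vectors (or density matrices) of different sizes into one fixed-shape array, leaving -2 in the unused tail of each row. A small illustration of that padding convention; the array name and sizes are made up:

import numpy as np

biggest_circuit = 3                                   # widest circuit has 3 qubits
batch_states = np.full((2, 1 << biggest_circuit), -2, dtype=np.complex64)

state = np.array([1.0, 0.0], dtype=np.complex64)      # a 1-qubit state vector
state_size = state.shape[-1]
sub_index = (slice(None, state_size, 1),) * (batch_states.ndim - 1)
batch_states[0][sub_index] = state
print(batch_states[0])  # [1.+0.j  0.+0.j  -2.+0.j  ...  -2.+0.j]
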


def batch_calculate_expectation(circuits, param_resolvers, ops, simulator):
@@ -495,7 +369,7 @@ def batch_calculate_sampled_expectation(circuits, param_resolvers, ops,


def batch_sample(circuits, param_resolvers, n_samples, simulator):
"""Sample from circuits using parallel processing.
"""Sample from circuits.
Returns a `np.ndarray` containing n_samples samples from all the circuits in
circuits given that the corresponding `cirq.ParamResolver` in
@@ -527,7 +401,7 @@ def batch_sample(circuits, param_resolvers, n_samples, simulator):
"""
_validate_inputs(circuits, param_resolvers, simulator, 'sample')
if _check_empty(circuits):
return np.zeros((0, 0, 0), dtype=np.int32)
return np.zeros((0, 0, 0), dtype=np.int8)

if not isinstance(n_samples, int):
raise TypeError('n_samples must be an int.'
@@ -538,30 +412,17 @@ def batch_sample(circuits, param_resolvers, n_samples, simulator):

biggest_circuit = max(len(circuit.all_qubits()) for circuit in circuits)
return_mem_shape = (len(circuits), n_samples, biggest_circuit)
shared_array = _make_simple_view(return_mem_shape, -2, np.int32, 'i')

if isinstance(simulator, cirq.DensityMatrixSimulator):
post_process = lambda state, size, n_samples: \
cirq.sample_density_matrix(
state.final_density_matrix, [i for i in range(size)],
repetitions=n_samples)
elif isinstance(simulator, cirq.Simulator):
post_process = lambda state, size, n_samples: cirq.sample_state_vector(
state.final_state_vector, list(range(size)), repetitions=n_samples)
else:
raise TypeError('Simulator {} is not supported by batch_sample.'.format(
type(simulator)))

input_args = list(
_prep_pool_input_args(range(len(circuits)), circuits, param_resolvers,
[n_samples] * len(circuits)))
return_array = np.ones(return_mem_shape, dtype=np.int8) * -2

with ProcessPool(processes=None,
initializer=_setup_dict,
initargs=(shared_array, return_mem_shape, simulator,
post_process)) as pool:
for batch, (c, resolver) in enumerate(zip(circuits, param_resolvers)):
if len(c.all_qubits()) == 0:
continue

pool.starmap(_sample_worker_func, input_args)
qb_keys = [(q, str(i)) for i, q in enumerate(sorted(c.all_qubits()))]
c_m = c + cirq.Circuit(cirq.measure(q, key=i) for q, i in qb_keys)
run_c = cirq.resolve_parameters(c_m, resolver)
bits = simulator.sample(run_c, repetitions=n_samples)
flat_m = bits[[x[1] for x in qb_keys]].to_numpy().astype(np.int8)
return_array[batch, :, biggest_circuit - len(qb_keys):] = flat_m

return _convert_simple_view_to_result(shared_array, np.int32,
return_mem_shape)
return return_array
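
For context, a hedged usage sketch of the updated batch_sample; the import path is taken from the file header above, and the circuit and assertions are illustrative. The output is now int8, with left-padding of -2 for circuits narrower than the widest one in the batch.

import cirq
import numpy as np
import sympy
from tensorflow_quantum.core.ops import batch_util

qubits = cirq.GridQubit.rect(1, 2)
alpha = sympy.Symbol('alpha')
circuits = [cirq.Circuit(cirq.X(qubits[0]) ** alpha, cirq.CNOT(qubits[0], qubits[1]))]
resolvers = [cirq.ParamResolver({'alpha': 0.5})]

samples = batch_util.batch_sample(circuits, resolvers, 100, cirq.Simulator())
assert samples.dtype == np.int8      # was int32 before this commit
assert samples.shape == (1, 100, 2)  # (batch, n_samples, biggest_circuit)
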
8 changes: 4 additions & 4 deletions tensorflow_quantum/core/ops/batch_util_test.py
@@ -174,7 +174,7 @@ def test_batch_sample_basic(self, sim):
expected_results = _sample_helper(sim, state, len(qubits), n_samples)

self.assertAllEqual(expected_results, test_results[0])
self.assertDTypeEqual(test_results, np.int32)
self.assertDTypeEqual(test_results, np.int8)

@parameterized.parameters([{
'sim': cirq.DensityMatrixSimulator()
@@ -210,7 +210,7 @@ def test_batch_sample(self, sim):
for a, b in zip(tfq_histograms, cirq_histograms):
self.assertLess(stats.entropy(a + 1e-8, b + 1e-8), 0.005)

self.assertDTypeEqual(results, np.int32)
self.assertDTypeEqual(results, np.int8)

@parameterized.parameters([{
'sim': cirq.DensityMatrixSimulator()
@@ -267,7 +267,7 @@ def test_empty_circuits(self, sim):
r = _sample_helper(sim, state, len(circuit.all_qubits()), n_samples)
self.assertAllClose(r, a, atol=1e-5)

self.assertDTypeEqual(results, np.int32)
self.assertDTypeEqual(results, np.int8)

@parameterized.parameters([{
'sim': cirq.DensityMatrixSimulator()
@@ -297,7 +297,7 @@ def test_no_circuit(self, sim):

# (4) Test sampling
results = batch_util.batch_sample([], [], [], sim)
self.assertDTypeEqual(results, np.int32)
self.assertDTypeEqual(results, np.int8)
self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape)


