diff --git a/tensorflow_quantum/core/ops/batch_util.py b/tensorflow_quantum/core/ops/batch_util.py index e95dd7f5d..75f148424 100644 --- a/tensorflow_quantum/core/ops/batch_util.py +++ b/tensorflow_quantum/core/ops/batch_util.py @@ -14,10 +14,7 @@ # ============================================================================== """A module to for running Cirq objects.""" import collections -import os -import multiprocessing as mp -from multiprocessing.pool import Pool as ProcessPool import numpy as np import cirq @@ -118,129 +115,6 @@ def _fixed_circuit_plus_pauli_string_measurements(circuit, pauli_string): return circuit -def _make_complex_view(shape, init_val): - """Build a RawArray that will map to the real and imaginary parts of a - complex number.""" - shape = list(shape) - shape[-1] *= 2 - data = np.ones(shape, dtype=np.float32) * init_val - - flattened_size = 1 - for dim_size in shape: - flattened_size *= dim_size - shared_mem_array = mp.RawArray('f', flattened_size) - np_view = np.frombuffer(shared_mem_array, dtype=np.float32).reshape(shape) - np.copyto(np_view, data) - return shared_mem_array - - -def _convert_complex_view_to_np(view, shape): - """Get a numpy view ontop of the rawarray view. Small overhead.""" - shape = list(shape) - shape[-1] *= 2 - return np.frombuffer(view, dtype=np.float32).reshape(shape) - - -def _update_complex_np(np_view, i, to_add): - """Update the shared memory undernath the numpy view. - to_add is passed by reference since we don't do much with it.""" - np_view[i, ...] = np.pad(to_add, - (0, (np_view.shape[-1] // 2 - to_add.shape[-1])), - 'constant', - constant_values=-2).view(np.float32) - - -def _convert_complex_view_to_result(view, shape): - """Convert a rawarray view to a numpy array and reindex so that - the underlying pair of double arrays are squished together to make a - complex array of half the underlying size.""" - shape = list(shape) - shape[-1] *= 2 - np_view = np.frombuffer(view, dtype=np.float32).reshape(shape) - - # The below view will cause a re-interpretation of underlying - # memory so use sparingly. - return np_view.view(np.complex64) - - -def _make_simple_view(shape, init_val, dtype, c_code): - """Make a shared memory view for floating type.""" - data = np.ones(shape, dtype=dtype) * init_val - flattened_size = 1 - for dim_size in shape: - flattened_size *= dim_size - shared_mem_array = mp.RawArray(c_code, flattened_size) - np_view = np.frombuffer(shared_mem_array, dtype=dtype).reshape(shape) - np.copyto(np_view, data) - return shared_mem_array - - -def _convert_simple_view_to_np(view, dtype, shape): - """Create a numpy view to a float array, low overhead.""" - return np.frombuffer(view, dtype=dtype).reshape(shape) - - -def _batch_update_simple_np(np_view, i, to_add): - """Update the shared memory underneath the numpy view. - to_add is again passed by reference.""" - np_view[i, ...] = to_add - - -def _pointwise_update_simple_np(np_view, i, j, to_add): - """Do a batch and sub-batch index update to numpy view.""" - np_view[i, j, ...] = to_add - - -def _convert_simple_view_to_result(view, dtype, shape): - """Convert a RawArray view to final numpy array.""" - return np.frombuffer(view, dtype=dtype).reshape(shape) - - -def _prep_pool_input_args(indices, *args, slice_args=True): - """Break down a set of indices, and optional args into a generator - of length cpu_count.""" - block_size = int(np.ceil(len(indices) / os.cpu_count())) - for i in range(0, len(indices), block_size): - if slice_args: - yield tuple([indices[i:i + block_size]] + - [x[i:i + block_size] for x in args]) - else: - yield tuple([indices[i:i + block_size]] + [x for x in args]) - - -# process are separate from all the other processes, -# so INFO_DICTs will not step on each other. -INFO_DICT = {} - - -def _setup_dict(array_view, view_shape, simulator, post_process): - INFO_DICT['arr'] = array_view - INFO_DICT['shape'] = view_shape - INFO_DICT['sim'] = simulator - INFO_DICT['post_process'] = post_process - - -def _sample_worker_func(indices, programs, params, n_samples): - """Sample n_samples from progams[i] with params[i] placed in it.""" - x_np = _convert_simple_view_to_np(INFO_DICT['arr'], np.int32, - INFO_DICT['shape']) - simulator = INFO_DICT['sim'] - - for i, index in enumerate(indices): - qubits = sorted(programs[i].all_qubits()) - # (#679) Just ignore empty programs. - if len(qubits) == 0: - continue - state = simulator.simulate(programs[i], params[i]) - samples = INFO_DICT['post_process'](state, len(qubits), - n_samples[i]).astype(np.int32) - _batch_update_simple_np( - x_np, index, - np.pad(samples, ((0, 0), (x_np.shape[2] - len(qubits), 0)), - 'constant', - constant_values=-2)) - - def _validate_inputs(circuits, param_resolvers, simulator, sim_type): """Type check and sanity check inputs.""" if not isinstance(circuits, (list, tuple, np.ndarray)): @@ -324,24 +198,24 @@ def batch_calculate_state(circuits, param_resolvers, simulator): return empty_ret biggest_circuit = max(len(circuit.all_qubits()) for circuit in circuits) + + # Default to state vector unless we see densitymatrix. + return_mem_shape = (len(circuits), 1 << biggest_circuit) + post_process = lambda x: x.final_state_vector if isinstance(simulator, cirq.DensityMatrixSimulator): return_mem_shape = (len(circuits), 1 << biggest_circuit, 1 << biggest_circuit) post_process = lambda x: x.final_density_matrix - # Assumes anything else returns a state vector. - else: - return_mem_shape = (len(circuits), 1 << biggest_circuit) - post_process = lambda x: x.final_state_vector - shared_array = _make_complex_view(return_mem_shape, -2) - - x_np = _convert_complex_view_to_np(shared_array, return_mem_shape) + batch_states = np.ones(return_mem_shape, dtype=np.complex64) * -2 for index, (program, param) in enumerate(zip(circuits, param_resolvers)): result = simulator.simulate(program, param) - final_array = post_process(result).astype(np.complex64) - _update_complex_np(x_np, index, final_array) + state_size = 1 << len(program.all_qubits()) + state = post_process(result).astype(np.complex64) + sub_index = (slice(None, state_size, 1),) * (batch_states.ndim - 1) + batch_states[index][sub_index] = state - return _convert_complex_view_to_result(shared_array, return_mem_shape) + return batch_states def batch_calculate_expectation(circuits, param_resolvers, ops, simulator): @@ -495,7 +369,7 @@ def batch_calculate_sampled_expectation(circuits, param_resolvers, ops, def batch_sample(circuits, param_resolvers, n_samples, simulator): - """Sample from circuits using parallel processing. + """Sample from circuits. Returns a `np.ndarray` containing n_samples samples from all the circuits in circuits given that the corresponding `cirq.ParamResolver` in @@ -527,7 +401,7 @@ def batch_sample(circuits, param_resolvers, n_samples, simulator): """ _validate_inputs(circuits, param_resolvers, simulator, 'sample') if _check_empty(circuits): - return np.zeros((0, 0, 0), dtype=np.int32) + return np.zeros((0, 0, 0), dtype=np.int8) if not isinstance(n_samples, int): raise TypeError('n_samples must be an int.' @@ -538,30 +412,17 @@ def batch_sample(circuits, param_resolvers, n_samples, simulator): biggest_circuit = max(len(circuit.all_qubits()) for circuit in circuits) return_mem_shape = (len(circuits), n_samples, biggest_circuit) - shared_array = _make_simple_view(return_mem_shape, -2, np.int32, 'i') - - if isinstance(simulator, cirq.DensityMatrixSimulator): - post_process = lambda state, size, n_samples: \ - cirq.sample_density_matrix( - state.final_density_matrix, [i for i in range(size)], - repetitions=n_samples) - elif isinstance(simulator, cirq.Simulator): - post_process = lambda state, size, n_samples: cirq.sample_state_vector( - state.final_state_vector, list(range(size)), repetitions=n_samples) - else: - raise TypeError('Simulator {} is not supported by batch_sample.'.format( - type(simulator))) - - input_args = list( - _prep_pool_input_args(range(len(circuits)), circuits, param_resolvers, - [n_samples] * len(circuits))) + return_array = np.ones(return_mem_shape, dtype=np.int8) * -2 - with ProcessPool(processes=None, - initializer=_setup_dict, - initargs=(shared_array, return_mem_shape, simulator, - post_process)) as pool: + for batch, (c, resolver) in enumerate(zip(circuits, param_resolvers)): + if len(c.all_qubits()) == 0: + continue - pool.starmap(_sample_worker_func, input_args) + qb_keys = [(q, str(i)) for i, q in enumerate(sorted(c.all_qubits()))] + c_m = c + cirq.Circuit(cirq.measure(q, key=i) for q, i in qb_keys) + run_c = cirq.resolve_parameters(c_m, resolver) + bits = simulator.sample(run_c, repetitions=n_samples) + flat_m = bits[[x[1] for x in qb_keys]].to_numpy().astype(np.int8) + return_array[batch, :, biggest_circuit - len(qb_keys):] = flat_m - return _convert_simple_view_to_result(shared_array, np.int32, - return_mem_shape) + return return_array diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index 9c6aa995d..3352b5ba2 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -174,7 +174,7 @@ def test_batch_sample_basic(self, sim): expected_results = _sample_helper(sim, state, len(qubits), n_samples) self.assertAllEqual(expected_results, test_results[0]) - self.assertDTypeEqual(test_results, np.int32) + self.assertDTypeEqual(test_results, np.int8) @parameterized.parameters([{ 'sim': cirq.DensityMatrixSimulator() @@ -210,7 +210,7 @@ def test_batch_sample(self, sim): for a, b in zip(tfq_histograms, cirq_histograms): self.assertLess(stats.entropy(a + 1e-8, b + 1e-8), 0.005) - self.assertDTypeEqual(results, np.int32) + self.assertDTypeEqual(results, np.int8) @parameterized.parameters([{ 'sim': cirq.DensityMatrixSimulator() @@ -267,7 +267,7 @@ def test_empty_circuits(self, sim): r = _sample_helper(sim, state, len(circuit.all_qubits()), n_samples) self.assertAllClose(r, a, atol=1e-5) - self.assertDTypeEqual(results, np.int32) + self.assertDTypeEqual(results, np.int8) @parameterized.parameters([{ 'sim': cirq.DensityMatrixSimulator() @@ -297,7 +297,7 @@ def test_no_circuit(self, sim): # (4) Test sampling results = batch_util.batch_sample([], [], [], sim) - self.assertDTypeEqual(results, np.int32) + self.assertDTypeEqual(results, np.int8) self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape)