Skip to content

Commit

Permalink
Merge pull request tensorflow#309 from tensorflow/s_expectation_parallel
Browse files Browse the repository at this point in the history
Parallel upgrade for sampled expectation.
  • Loading branch information
jaeyoo authored Jul 17, 2020
2 parents a068f2a + b74370e commit ec7077d
Showing 1 changed file with 117 additions and 17 deletions.
134 changes: 117 additions & 17 deletions tensorflow_quantum/core/ops/tfq_simulate_sampled_expectation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "../qsim/lib/circuit.h"
#include "../qsim/lib/gate_appl.h"
#include "../qsim/lib/gates_cirq.h"
#include "../qsim/lib/seqfor.h"
#include "../qsim/lib/simmux.h"
#include "cirq/google/api/v2/program.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand All @@ -43,10 +44,10 @@ typedef qsim::Circuit<QsimGate> QsimCircuit;
class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
public:
explicit TfqSimulateSampledExpectationOp(
tensorflow::OpKernelConstruction *context)
tensorflow::OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(tensorflow::OpKernelContext *context) override {
void Compute(tensorflow::OpKernelContext* context) override {
// TODO (mbbrough): add more dimension checks for other inputs here.
const int num_inputs = context->num_inputs();
OP_REQUIRES(context, num_inputs == 5,
Expand All @@ -60,7 +61,7 @@ class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
output_shape.AddDim(output_dim_batch_size);
output_shape.AddDim(output_dim_op_size);

tensorflow::Tensor *output = nullptr;
tensorflow::Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
auto output_tensor = output->matrix<float>();

Expand Down Expand Up @@ -112,9 +113,45 @@ class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
programs.size(), num_cycles, construct_f);

int max_num_qubits = 0;
for (const int num : num_qubits) {
max_num_qubits = std::max(max_num_qubits, num);
}

// Cross reference with standard google cloud compute instances
// Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits)
// e2s2 = 2 CPU, 8GB -> Can safely do 25 since Memory = 4GB
// e2s4 = 4 CPU, 16GB -> Can safely do 25 since Memory = 8GB
// ...
if (max_num_qubits < 26) {
ComputeSmall(num_qubits, max_num_qubits, fused_circuits, pauli_sums,
num_samples, context, &output_tensor);
} else {
ComputeLarge(num_qubits, fused_circuits, pauli_sums, num_samples, context,
&output_tensor);
}

// just to be on the safe side.
qsim_circuits.clear();
fused_circuits.clear();
num_qubits.clear();
maps.clear();
pauli_sums.clear();
num_samples.clear();
programs.clear();
}

private:
void ComputeLarge(
const std::vector<int>& num_qubits,
const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
const std::vector<std::vector<PauliSum>>& pauli_sums,
const std::vector<std::vector<int>>& num_samples,
tensorflow::OpKernelContext* context,
tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
// Instantiate qsim objects.
const auto tfq_for = tfq::QsimFor(context);
using Simulator = qsim::Simulator<const tfq::QsimFor &>;
using Simulator = qsim::Simulator<const tfq::QsimFor&>;
using StateSpace = Simulator::StateSpace;
using State = StateSpace::State;

Expand All @@ -126,7 +163,7 @@ class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
// Simulate programs one by one. Since we parallelize over wavefunctions,
// we no longer parallelize over circuits. Each time we encounter
// a larger circuit we will grow the Statevector as necessary.
for (int i = 0; i < programs.size(); i++) {
for (int i = 0; i < fused_circuits.size(); i++) {
int nq = num_qubits[i];
Simulator sim = Simulator(nq, tfq_for);
StateSpace ss = StateSpace(nq, tfq_for);
Expand All @@ -145,27 +182,90 @@ class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
}
for (int j = 0; j < pauli_sums[i].size(); j++) {
// (#679) Just ignore empty program
if (programs[i].circuit().moments().empty()) {
output_tensor(i, j) = -2.0;
if (fused_circuits[i].size() == 0) {
(*output_tensor)(i, j) = -2.0;
continue;
}
float exp_v = 0.0;
OP_REQUIRES_OK(context, ComputeSampledExpectationQsim(
pauli_sums[i][j], sim, ss, sv, scratch,
num_samples[i][j], &exp_v));
output_tensor(i, j) = exp_v;
(*output_tensor)(i, j) = exp_v;
}
}
// just to be on the safe side.
sv.release();
scratch.release();
qsim_circuits.clear();
fused_circuits.clear();
num_qubits.clear();
maps.clear();
pauli_sums.clear();
num_samples.clear();
programs.clear();
}

// Computes sampled expectation values for every (circuit, pauli_sum) pair by
// parallelizing over the flattened (batch index, op index) space, while each
// individual simulation runs with a qsim::SequentialFor(1).
//
// Args:
//   num_qubits: qubit count per circuit, indexed by batch index.
//   max_num_qubits: only used here to build the per-item cost estimate that
//     is handed to ParallelFor.
//   fused_circuits: fused gate list per circuit, indexed by batch index. An
//     empty list is treated as an empty program.
//   pauli_sums: indexed as [batch index][op index].
//   num_samples: indexed as [batch index][op index].
//   context: kernel context; provides the worker thread pool and receives
//     any error status reported through OP_REQUIRES_OK.
//   output_tensor: matrix written at (batch index, op index). Its second
//     dimension defines how many ops there are per circuit.
//
// Entries belonging to an empty program are set to -2.0.
void ComputeSmall(
const std::vector<int>& num_qubits, const int max_num_qubits,
const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
const std::vector<std::vector<PauliSum>>& pauli_sums,
const std::vector<std::vector<int>>& num_samples,
tensorflow::OpKernelContext* context,
tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
// Despite the name, this is a sequential "for": parallelism comes from the
// ParallelFor call at the bottom of this function, not from qsim.
const auto tfq_for = qsim::SequentialFor(1);
using Simulator = qsim::Simulator<const qsim::SequentialFor&>;
using StateSpace = Simulator::StateSpace;
using State = StateSpace::State;

// Number of ops (columns) per circuit; used to unflatten the work index.
const int output_dim_op_size = output_tensor->dimension(1);

// Processes the flattened work items in [start, end). Item i maps to
// batch index i / output_dim_op_size and op index i % output_dim_op_size.
auto DoWork = [&](int start, int end) {
// old_batch_index starts at a value no real batch index can equal, so the
// first non-empty item in this shard always simulates its circuit.
int old_batch_index = -2;
int cur_batch_index = -1;
// Qubit count the current sv/scratch buffers were last allocated for.
int largest_nq = 1;
int cur_op_index;

State sv = StateSpace(largest_nq, tfq_for).CreateState();
State scratch = StateSpace(largest_nq, tfq_for).CreateState();
for (int i = start; i < end; i++) {
cur_batch_index = i / output_dim_op_size;
cur_op_index = i % output_dim_op_size;

const int nq = num_qubits[cur_batch_index];
Simulator sim = Simulator(nq, tfq_for);
StateSpace ss = StateSpace(nq, tfq_for);

// (#679) Just ignore empty program
// Note: this path does not update old_batch_index.
if (fused_circuits[cur_batch_index].size() == 0) {
(*output_tensor)(cur_batch_index, cur_op_index) = -2.0;
continue;
}

if (cur_batch_index != old_batch_index) {
// We've run into a new wavefunction we must compute.
// Only compute a new wavefunction when we have to.
// NOTE(review): `>=` also reallocates when nq equals largest_nq; `>`
// may be sufficient if an existing buffer of equal size can be reused
// -- confirm against the qsim State semantics.
if (nq >= largest_nq) {
sv = ss.CreateState();
scratch = ss.CreateState();
largest_nq = nq;
}
// no need to update scratch_state since ComputeExpectation
// will take care of things for us.
ss.SetStateZero(sv);
// NOTE(review): `int j` is compared against an unsigned size() here.
for (int j = 0; j < fused_circuits[cur_batch_index].size(); j++) {
qsim::ApplyFusedGate(sim, fused_circuits[cur_batch_index][j], sv);
}
}

// sv still holds the wavefunction of cur_batch_index at this point, either
// freshly simulated above or carried over from the previous item.
float exp_v = 0.0;
// NOTE(review): OP_REQUIRES_OK presumably returns from this lambda on
// failure, leaving the rest of this shard unwritten -- confirm.
OP_REQUIRES_OK(
context,
ComputeSampledExpectationQsim(
pauli_sums[cur_batch_index][cur_op_index], sim, ss, sv, scratch,
num_samples[cur_batch_index][cur_op_index], &exp_v));
(*output_tensor)(cur_batch_index, cur_op_index) = exp_v;
old_batch_index = cur_batch_index;
}
// NOTE(review): if State is a std::unique_ptr, release() gives up ownership
// without freeing the buffer, which would leak -- confirm the State type.
sv.release();
scratch.release();
};

// Per-item cost estimate passed to ParallelFor: 200 * 2^max_num_qubits.
const int64_t num_cycles =
200 * (int64_t(1) << static_cast<int64_t>(max_num_qubits));
// One work item per (circuit, op) pair.
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
fused_circuits.size() * output_dim_op_size, num_cycles, DoWork);
}
};

Expand All @@ -180,7 +280,7 @@ REGISTER_OP("TfqSimulateSampledExpectation")
.Input("pauli_sums: string")
.Input("num_samples: int32")
.Output("expectations: float")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
.SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
tensorflow::shape_inference::ShapeHandle programs_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &programs_shape));

Expand Down

0 comments on commit ec7077d

Please sign in to comment.