Skip to content

Commit

Permalink
fixed the retraining bug. I had it fixed in my original testing repo …
Browse files Browse the repository at this point in the history
…but not the public one. Eventually I can allow for reloading of symengine expressions to speed up retraining slightly instead of remaking the symengine expression all over again for each run.
  • Loading branch information
wtroy2 committed Jul 31, 2024
1 parent 08b7ec1 commit 7fb9eb8
Showing 1 changed file with 123 additions and 173 deletions.
296 changes: 123 additions & 173 deletions cpp/quantum_kan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,37 +1059,11 @@ void save_data(
void save_data_2_layer(
const RCP<const Basic>& symbolic_sum_no_mean,
int x_data_size,
const RCP<const Basic>& aux_all_sub_expressions_equation,
const unordered_map<RCP<const Basic>, RCP<const Basic>>& aux_dict_final,
const vector<vector<RCP<const Basic>>>& coefficients_plus1,
const vector<vector<RCP<const Basic>>>& coefficients_plus2,
const vector<vector<RCP<const Basic>>>& coefficients_plus3,
const string& filename
) {
json j;
j["symbolic_sum_no_mean"] = symbolic_sum_no_mean->__str__();
j["x_data_size"] = x_data_size;
j["aux_all_sub_expressions_equation"] = aux_all_sub_expressions_equation->__str__();

json aux_dict_json;
for (const auto& pair : aux_dict_final) {
aux_dict_json[pair.first->__str__()] = pair.second->__str__();
}
j["aux_dict_final"] = aux_dict_json;

auto convert_coefficients_to_json = [](const vector<vector<RCP<const Basic>>>& coefficients) {
json coeffs_json;
for (size_t i = 0; i < coefficients.size(); ++i) {
for (size_t j = 0; j < coefficients[i].size(); ++j) {
coeffs_json[to_string(i)][to_string(j)] = coefficients[i][j]->__str__();
}
}
return coeffs_json;
};

j["coefficients_plus1"] = convert_coefficients_to_json(coefficients_plus1);
j["coefficients_plus2"] = convert_coefficients_to_json(coefficients_plus2);
j["coefficients_plus3"] = convert_coefficients_to_json(coefficients_plus3);

save_json_to_file(j, filename);
}
Expand Down Expand Up @@ -1136,38 +1110,12 @@ void load_data(
void load_data_2_layer(
RCP<const Basic>& symbolic_sum_no_mean,
int& x_data_size,
RCP<const Basic>& aux_all_sub_expressions_equation,
unordered_map<RCP<const Basic>, RCP<const Basic>>& aux_dict_final,
vector<vector<RCP<const Basic>>>& coefficients_plus1,
vector<vector<RCP<const Basic>>>& coefficients_plus2,
vector<vector<RCP<const Basic>>>& coefficients_plus3,
const string& filename
) {
json j = load_json_from_file(filename);

symbolic_sum_no_mean = SymEngine::parse(j["symbolic_sum_no_mean"].get<string>());
x_data_size = j["x_data_size"].get<int>();
aux_all_sub_expressions_equation = SymEngine::parse(j["aux_all_sub_expressions_equation"].get<string>());

for (const auto& pair : j["aux_dict_final"].items()) {
aux_dict_final[SymEngine::parse(pair.key())] = SymEngine::parse(pair.value().get<string>());
}

auto convert_json_to_coefficients = [](const json& coeffs_json) {
vector<vector<RCP<const Basic>>> coefficients;
for (const auto& row : coeffs_json.items()) {
vector<RCP<const Basic>> coeff_row;
for (const auto& coeff : row.value().items()) {
coeff_row.push_back(SymEngine::parse(coeff.value().get<string>()));
}
coefficients.push_back(coeff_row);
}
return coefficients;
};

coefficients_plus1 = convert_json_to_coefficients(j["coefficients_plus1"]);
coefficients_plus2 = convert_json_to_coefficients(j["coefficients_plus2"]);
coefficients_plus3 = convert_json_to_coefficients(j["coefficients_plus3"]);
}

// Function to precompute powers and apply auxiliary variables
Expand Down Expand Up @@ -1531,167 +1479,169 @@ compute_mse_with_penalty(int d1, int d2, int d3, int m1, int m2, int m3, double
int x_data_size_old;

if (!load_filename.empty()) {
load_data_2_layer(preloaded_symbolic_sum, x_data_size_old, aux_all_sub_expressions_equation, aux_dict_final, coefficients_plus1, coefficients_plus2, coefficients_plus3, load_filename);
load_data_2_layer(preloaded_symbolic_sum, x_data_size_old, load_filename);
x_data_size = x_data_size + x_data_size_old;
} else {
// Define symbolic binary variables for the coefficients
coefficients_plus1.resize(degree1 + 1, vector<RCP<const Basic>>(m1));
coefficients_plus2.resize(degree2 + 1, vector<RCP<const Basic>>(m2));
coefficients_plus3.resize(degree3 + 1, vector<RCP<const Basic>>(m3));
// initialize_coefficients();

for (int i = 0; i <= degree1; ++i) {
for (int j = 0; j < m1; ++j) {
coefficients_plus1[i][j] = binary("P1_" + to_string(i) + "_plus_" + to_string(j));
}
}
// Note: Fix the preloading of the symengine equation later. Commented out for now.
// } else {
// Define symbolic binary variables for the coefficients
coefficients_plus1.resize(degree1 + 1, vector<RCP<const Basic>>(m1));
coefficients_plus2.resize(degree2 + 1, vector<RCP<const Basic>>(m2));
coefficients_plus3.resize(degree3 + 1, vector<RCP<const Basic>>(m3));
// initialize_coefficients();

for (int i = 0; i <= degree1; ++i) {
for (int j = 0; j < m1; ++j) {
coefficients_plus1[i][j] = binary("P1_" + to_string(i) + "_plus_" + to_string(j));
}
}

for (int i = 0; i <= degree2; ++i) {
for (int j = 0; j < m2; ++j) {
coefficients_plus2[i][j] = binary("P2_" + to_string(i) + "_plus_" + to_string(j));
}
for (int i = 0; i <= degree2; ++i) {
for (int j = 0; j < m2; ++j) {
coefficients_plus2[i][j] = binary("P2_" + to_string(i) + "_plus_" + to_string(j));
}
}

for (int i = 0; i <= degree3; ++i) {
for (int j = 0; j < m3; ++j) {
coefficients_plus3[i][j] = binary("P3_" + to_string(i) + "_plus_" + to_string(j));
}
for (int i = 0; i <= degree3; ++i) {
for (int j = 0; j < m3; ++j) {
coefficients_plus3[i][j] = binary("P3_" + to_string(i) + "_plus_" + to_string(j));
}
// Define control points
vector<RCP<const Basic>> coefficients_A;
vector<RCP<const Basic>> coefficients_B;
vector<RCP<const Basic>> coefficients_C;
}
// Define control points
vector<RCP<const Basic>> coefficients_A;
vector<RCP<const Basic>> coefficients_B;
vector<RCP<const Basic>> coefficients_C;

for (int i = 0; i <= degree1; ++i) {
coefficients_A.push_back(symbol("A" + std::to_string(i)));
}
for (int i = 0; i <= degree1; ++i) {
coefficients_A.push_back(symbol("A" + std::to_string(i)));
}

for (int i = 0; i <= degree2; ++i) {
coefficients_B.push_back(symbol("B" + std::to_string(i)));
}
for (int i = 0; i <= degree2; ++i) {
coefficients_B.push_back(symbol("B" + std::to_string(i)));
}

for (int i = 0; i <= degree3; ++i) {
coefficients_C.push_back(symbol("C" + std::to_string(i)));
}
for (int i = 0; i <= degree3; ++i) {
coefficients_C.push_back(symbol("C" + std::to_string(i)));
}

// Generate coefficient expressions
vector<RCP<const Basic>> coeff_expressions1(degree1 + 1);
vector<RCP<const Basic>> coeff_expressions2(degree2 + 1);
vector<RCP<const Basic>> coeff_expressions3(degree3 + 1);
// Generate coefficient expressions
vector<RCP<const Basic>> coeff_expressions1(degree1 + 1);
vector<RCP<const Basic>> coeff_expressions2(degree2 + 1);
vector<RCP<const Basic>> coeff_expressions3(degree3 + 1);

for (int i = 0; i <= degree1; ++i) {
coeff_expressions1[i] = generate_coefficient_expr(coefficients_plus1, degree1, m1, i);
}
for (int i = 0; i <= degree1; ++i) {
coeff_expressions1[i] = generate_coefficient_expr(coefficients_plus1, degree1, m1, i);
}

for (int i = 0; i <= degree2; ++i) {
coeff_expressions2[i] = generate_coefficient_expr(coefficients_plus2, degree2, m2, i);
}
for (int i = 0; i <= degree2; ++i) {
coeff_expressions2[i] = generate_coefficient_expr(coefficients_plus2, degree2, m2, i);
}

for (int i = 0; i <= degree3; ++i) {
coeff_expressions3[i] = generate_coefficient_expr(coefficients_plus3, degree3, m3, i);
}
for (int i = 0; i <= degree3; ++i) {
coeff_expressions3[i] = generate_coefficient_expr(coefficients_plus3, degree3, m3, i);
}

// Create an empty aux_dict
unordered_map<RCP<const Basic>, RCP<const Basic>> existing_aux_dict_precomputed_powers;
// Create an empty aux_dict
unordered_map<RCP<const Basic>, RCP<const Basic>> existing_aux_dict_precomputed_powers;

// Precompute powers
unordered_map<string, unordered_map<int, RCP<const Basic>>> precomputed_powers;
for (int i = 0; i <= degree1; ++i) {
precomputed_powers["A" + to_string(i)] = precompute_powers(coeff_expressions1[i], 8, existing_aux_dict_precomputed_powers, true);
}
for (int i = 0; i <= degree2; ++i) {
precomputed_powers["B" + to_string(i)] = precompute_powers(coeff_expressions2[i], 8, existing_aux_dict_precomputed_powers, true);
}
for (int i = 0; i <= degree3; ++i) {
precomputed_powers["C" + to_string(i)] = precompute_powers(coeff_expressions3[i], 8, existing_aux_dict_precomputed_powers, true);
}
// Precompute powers
unordered_map<string, unordered_map<int, RCP<const Basic>>> precomputed_powers;
for (int i = 0; i <= degree1; ++i) {
precomputed_powers["A" + to_string(i)] = precompute_powers(coeff_expressions1[i], 8, existing_aux_dict_precomputed_powers, true);
}
for (int i = 0; i <= degree2; ++i) {
precomputed_powers["B" + to_string(i)] = precompute_powers(coeff_expressions2[i], 8, existing_aux_dict_precomputed_powers, true);
}
for (int i = 0; i <= degree3; ++i) {
precomputed_powers["C" + to_string(i)] = precompute_powers(coeff_expressions3[i], 8, existing_aux_dict_precomputed_powers, true);
}

// Define symbolic variables
RCP<const Basic> x = symbol("x");
RCP<const Basic> y = symbol("y");
RCP<const Basic> t = symbol("t");
RCP<const Basic> z = symbol("z");
// Define symbolic variables
RCP<const Basic> x = symbol("x");
RCP<const Basic> y = symbol("y");
RCP<const Basic> t = symbol("t");
RCP<const Basic> z = symbol("z");

// Compute the symbolic basis functions
auto continuous_bezier_expr1 = bernstein_basis_functions_symbolic_continuous_control(x, degree1, coefficients_A);
auto continuous_bezier_expr2 = bernstein_basis_functions_symbolic_continuous_control(y, degree2, coefficients_B);
auto continuous_bezier_expr3 = bernstein_basis_functions_symbolic_continuous_control(t, degree3, coefficients_C);
// Compute the symbolic basis functions
auto continuous_bezier_expr1 = bernstein_basis_functions_symbolic_continuous_control(x, degree1, coefficients_A);
auto continuous_bezier_expr2 = bernstein_basis_functions_symbolic_continuous_control(y, degree2, coefficients_B);
auto continuous_bezier_expr3 = bernstein_basis_functions_symbolic_continuous_control(t, degree3, coefficients_C);

RCP<const Basic> bezier_expr1, bezier_expr2, bezier_expr3;
RCP<const Basic> bezier_expr1, bezier_expr2, bezier_expr3;

// Combine the two Bézier functions
auto combined_continuous_bottom_expr = expand(add(continuous_bezier_expr1, continuous_bezier_expr2));
// Combine the two Bézier functions
auto combined_continuous_bottom_expr = expand(add(continuous_bezier_expr1, continuous_bezier_expr2));

// Create the power of the third Bézier function
auto bezier_continuous_expr3_2 = expand(pow(continuous_bezier_expr3, integer(2)));
// Create the power of the third Bézier function
auto bezier_continuous_expr3_2 = expand(pow(continuous_bezier_expr3, integer(2)));

// Substitute combined_continuous_bottom_expr for t in bezier_continuous_expr3_2
map_basic_basic substitutions;
substitutions[t] = combined_continuous_bottom_expr;
auto substituted_expr = expand(bezier_continuous_expr3_2->subs(substitutions));
// Substitute combined_continuous_bottom_expr for t in bezier_continuous_expr3_2
map_basic_basic substitutions;
substitutions[t] = combined_continuous_bottom_expr;
auto substituted_expr = expand(bezier_continuous_expr3_2->subs(substitutions));

// Substitute precomputed powers in substituted_expr
auto final_expr = substitute_precomputed_powers(substituted_expr, precomputed_powers);
// Substitute precomputed powers in substituted_expr
auto final_expr = substitute_precomputed_powers(substituted_expr, precomputed_powers);

map_basic_basic substitutions_A;
map_basic_basic substitutions_A;

// Substitute coefficients in substituted_expr
for (int i = 0; i <= degree1; ++i) {
substitutions_A[symbol("A" + to_string(i))] = coeff_expressions1[i];
}
// Substitute coefficients in substituted_expr
for (int i = 0; i <= degree1; ++i) {
substitutions_A[symbol("A" + to_string(i))] = coeff_expressions1[i];
}

substituted_expr = final_expr->subs(substitutions_A);
substituted_expr = final_expr->subs(substitutions_A);

map_basic_basic substitutions_C;
map_basic_basic substitutions_C;

for (int i = 0; i <= degree3; ++i) {
substitutions_C[symbol("C" + to_string(i))] = coeff_expressions3[i];
}
substituted_expr = substituted_expr->subs(substitutions_C);
for (int i = 0; i <= degree3; ++i) {
substitutions_C[symbol("C" + to_string(i))] = coeff_expressions3[i];
}
substituted_expr = substituted_expr->subs(substitutions_C);

map_basic_basic substitutions_B;
map_basic_basic substitutions_B;

for (int i = 0; i <= degree2; ++i) {
substitutions_B[symbol("B" + to_string(i))] = coeff_expressions2[i];
}
for (int i = 0; i <= degree2; ++i) {
substitutions_B[symbol("B" + to_string(i))] = coeff_expressions2[i];
}

auto final_substituted_expr = expand(substituted_expr->subs(substitutions_B));
auto final_substituted_expr = expand(substituted_expr->subs(substitutions_B));

// Create auxiliary variables
auto result = apply_aux_variables(final_substituted_expr, existing_aux_dict_precomputed_powers, false);
final_substituted_expr = result.first;
existing_aux_dict_precomputed_powers = result.second;
// Create auxiliary variables
auto result = apply_aux_variables(final_substituted_expr, existing_aux_dict_precomputed_powers, false);
final_substituted_expr = result.first;
existing_aux_dict_precomputed_powers = result.second;

// Now define the z^2
auto z_squared= pow(z, integer(2));
// Now define the z^2
auto z_squared= pow(z, integer(2));

// Now define the middle expression
auto middle_expression = mul(mul(z, integer(-2)), continuous_bezier_expr3);
// Now define the middle expression
auto middle_expression = mul(mul(z, integer(-2)), continuous_bezier_expr3);

map_basic_basic substitutions_middle;
substitutions_middle[t] = combined_continuous_bottom_expr;
auto substituted_expr_middle = expand(middle_expression->subs(substitutions_middle));
map_basic_basic substitutions_middle;
substitutions_middle[t] = combined_continuous_bottom_expr;
auto substituted_expr_middle = expand(middle_expression->subs(substitutions_middle));

// Substitute precomputed powers in substituted_expr
auto final_expr_middle = substitute_precomputed_powers(substituted_expr_middle, precomputed_powers);
// Substitute precomputed powers in substituted_expr
auto final_expr_middle = substitute_precomputed_powers(substituted_expr_middle, precomputed_powers);

final_expr_middle = final_expr_middle->subs(substitutions_A);
final_expr_middle = final_expr_middle->subs(substitutions_A);

final_expr_middle = final_expr_middle->subs(substitutions_C);
final_expr_middle = final_expr_middle->subs(substitutions_C);

final_expr_middle = expand(final_expr_middle->subs(substitutions_B));
final_expr_middle = expand(final_expr_middle->subs(substitutions_B));

// Create auxiliary variables
result = apply_aux_variables(final_expr_middle, existing_aux_dict_precomputed_powers, false);
final_expr_middle = result.first;
aux_dict_final = result.second;
// Create auxiliary variables
result = apply_aux_variables(final_expr_middle, existing_aux_dict_precomputed_powers, false);
final_expr_middle = result.first;
aux_dict_final = result.second;

// Putting it all together:
aux_all_sub_expressions_equation = add(add(z_squared, final_expr_middle), final_substituted_expr);
// Putting it all together:
aux_all_sub_expressions_equation = add(add(z_squared, final_expr_middle), final_substituted_expr);

// Filter the auxiliary dictionary
filter_aux_dict(aux_all_sub_expressions_equation, aux_dict_final);
}
// Filter the auxiliary dictionary
filter_aux_dict(aux_all_sub_expressions_equation, aux_dict_final);
// }


auto unique_terms = extract_unique_xyz_terms(aux_all_sub_expressions_equation);
Expand Down Expand Up @@ -1772,7 +1722,7 @@ compute_mse_with_penalty(int d1, int d2, int d3, int m1, int m2, int m3, double

// Save the current state if a save_filename is provided
if (!save_filename.empty()) {
save_data_2_layer(symbolic_sum_no_mean, x_data_size, aux_all_sub_expressions_equation, aux_dict_final, coefficients_plus1, coefficients_plus2, coefficients_plus3, save_filename);
save_data_2_layer(symbolic_sum_no_mean, x_data_size, save_filename);
}

return std::make_tuple(sse_with_penalty_str, aux_dict_str, coeffs_plus1_str, coeffs_plus2_str, coeffs_plus3_str);
Expand Down

0 comments on commit 7fb9eb8

Please sign in to comment.