Skip to content

Commit

Permalink
Bugfix bisection edge conditions in prox_l1_and_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesfolberth committed Feb 28, 2019
1 parent 9a794ad commit 6043fa8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 31 deletions.
74 changes: 59 additions & 15 deletions mexFiles/prox_l1_and_sum_worker_mex.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ void Proxl1Sum::run(double *y,
// Sort break_points in decreasing order
std::sort(break_points_.rbegin(), break_points_.rend());

// Bisection search to find sum(x) - b == 0
//for (size_t i=0; i<break_points_.size(); ++i) {
// std::cout << "bp[i] = " << break_points_[i] << std::endl;
//}

// Bisection search to solve 0 == sum(x) - b =: h(a)
ptrdiff_t lower_bound = -1;
size_t num_break_points = (zero_ind >= 0) ? 2*n - 2 : 2*n;
ptrdiff_t upper_bound = num_break_points;
Expand All @@ -185,16 +189,16 @@ void Proxl1Sum::run(double *y,
+ static_cast<double>(upper_bound)));

// Check that we're not at either endpoint (which may be invalid indexes)
if (ind == lower_bound) {
if (ind == -1) {
++ind;
} else if (ind == upper_bound) {
} else if (ind == num_break_points) {
--ind;
}

// Evaluate the prox at this offset and check the function value
double a = break_points_[ind];
evaluate_prox(y, x0, n, lambda, a);
double h = evaluate_func(y, n, b);
double h = evaluate_func(y, n, b); // evaluate h(a)

if (zero_ind >= 0) {
h -= y[zero_ind];
Expand All @@ -208,24 +212,59 @@ void Proxl1Sum::run(double *y,
}
}

//std::cout << "lower_bound = " << lower_bound << std::endl;
//std::cout << "upper_bound = " << upper_bound << std::endl;

// Now determine linear part, which we infer from two points.
// If the lower or upper bounds are infinite, we take special care
// If the lower or upper bounds are outside the range of valid break points
// (i.e., they are infinite by the convention of Algorithm 3 of
// https://arxiv.org/pdf/1804.06291.pdf), we take special care
// by using a new "a" that is slightly bigger/lower, respectively.
// Note that lower_bound/upper_bound correspond to indexes of
// the break points vector, which is sorted in _decreasing_ order.
// Thus, to get _lower_ than lower_bound, we need a _higher_ a.
// This is then used to extract the linear part.
double a;
double a_left, a_right; // bracket of the root h(a) == 0
if (lower_bound == -1) {
a = break_points_[upper_bound];
// use a - 10 as lower bound
a = 0.5*(a - 10 + a);
// use a + 10 as right part of bracket
a_left = break_points_[upper_bound];
a_right = a_left + 10;
} else if (upper_bound == num_break_points) {
a = break_points_[lower_bound];
// use a + 10 as upper bound
a = 0.5*(a + a + 10);
// use a - 10 as left part of bracket
a_right = break_points_[lower_bound];
a_left = a_right - 10;
} else {
// general case
a = 0.5*(break_points_[lower_bound] + break_points_[upper_bound]);
a_left = break_points_[lower_bound];
a_right = break_points_[upper_bound];
}
double a = 0.5*(a_left + a_right);
//std::cout << "a = " << a << std::endl;

/*
evaluate_prox(y, x0, n, lambda, a_left);
double h_left = evaluate_func(y, n, b);
if (zero_ind >= 0) {
h_left -= y[zero_ind];
}
evaluate_prox(y, x0, n, lambda, a);
double h_a = evaluate_func(y, n, b);
if (zero_ind >= 0) {
h_a -= y[zero_ind];
}
evaluate_prox(y, x0, n, lambda, a_right);
double h_right = evaluate_func(y, n, b);
if (zero_ind >= 0) {
h_right -= y[zero_ind];
}
std::cout << "h(left) = " << h_left << " "
<< "h(a) = " << h_a << " "
<< "h(right) = " << h_right << std::endl;
*/

// Now we have the support; find the exact value
evaluate_prox(y, x0, n, lambda, a); // to find the support

Expand Down Expand Up @@ -260,6 +299,13 @@ void Proxl1Sum::run(double *y,
y[zero_ind] = 0.;
}

// h(nu) should be 0
//double h_nu = evaluate_func(y, n, b);
//if (zero_ind >= 0) {
// h_nu -= y[zero_ind];
// }
//std::cout << "h(nu) = " << h_nu << std::endl;

}

void Proxl1Sum::evaluate_prox(double *y,
Expand Down Expand Up @@ -326,8 +372,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
}
const double *lambda = mxGetPr(prhs[1]);



// Third input
const double b = mxGetScalar(prhs[2]);

Expand Down
33 changes: 23 additions & 10 deletions mexFiles/tests/test_prox_l1_and_sum.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function [] = test_prox_l1_and_sum()

check_correctness()
time_single_run()
%time_single_run()
%profile_single_run()

end
Expand All @@ -12,29 +12,42 @@

% An example from sparse subspace clustering (SSC)
%rng(271828);
n_rows = 700;
n_cols = 600;
% - Larger mxn randn will test the main/average part of the bisection.
% - To test the upper_bound edge condition, use 2x2 randn with positive b.
% - To test the lower_bound edge condition, use 2x2 randn with negative b.
n_rows = 2;
n_cols = 2;
x = randn(n_rows, n_cols);
x = x(:);
t = rand();
t = rand(n_cols,1);
%t = rand(1,n_cols);
Q = rand();
b = rand();

zeroID = true;
useMex = true;

% Make the prox operators
prox_ref = prox_l1_and_sum_ref(Q, b, n_cols, zeroID);
prox_test = prox_l1_and_sum(Q, b, n_cols, zeroID, useMex);
% JMF 26 Feb 2019: this doesn't appear to be correct in break-point edge cases; trust useMex=false version
%prox_ref = prox_l1_and_sum_ref(Q, b, n_cols, zeroID);
prox_ref = prox_l1_and_sum(Q, b, n_cols, zeroID, false);
prox_test = prox_l1_and_sum(Q, b, n_cols, zeroID, true);

[~,y_ref] = prox_ref(x, t);
[~,y_test] = prox_test(x, t);

%reshape(y_ref, [n_rows n_cols])
%reshape(y_test, [n_rows n_cols])

y_ref = reshape(y_ref, [n_rows n_cols]);
y_test = reshape(y_test, [n_rows n_cols]);

%y_ref
%y_test
%sum(y_ref, 1)
%sum(y_test, 1)

rel_err = norm(y_ref - y_test, 'fro') / norm(y_ref, 'fro');
fprintf('relative error = %1.5e\n', rel_err);

sum_err = norm(sum(y_test,1) - b, 'fro') / abs(b);
fprintf('relative error in sum = %1.5e\n', sum_err)

end

Expand Down
13 changes: 7 additions & 6 deletions prox_l1_and_sum.m
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@
nCols = size( x0, 2 ); % allow matrices
LDA = size( brk_pts, 1 );
offsets = (0:nCols-1)*LDA;%i.e., [0, LDA, 2*LDA, ... ];
num_brk_pts = LDA;

lwrBnd = zeros(1,nCols);
uprBnd = (length(brk_pts) + 1)*ones(1,nCols);
Expand Down Expand Up @@ -278,12 +279,12 @@
[a1,a2] = deal( zeros(1,nCols) );
ind = find( lwrBnd == 0 );
a2(ind) = brk_pts( uprBnd(ind) + offsets(ind) );
a1(ind) = a2(ind) - 10;
a1(ind) = a2(ind) + 10;
ind2 = ind;

ind = find( uprBnd == size(brk_pts,1) + 1 );
ind = find( uprBnd == num_brk_pts + 1 );
a1(ind) = brk_pts( lwrBnd(ind) + offsets(ind) );
a2(ind) = a1(ind) + 10;
a2(ind) = a1(ind) - 10;

indOther = setdiff( 1:nCols, [ind2,ind] );
a1(indOther) = brk_pts( lwrBnd(indOther) + offsets(indOther) );
Expand Down Expand Up @@ -372,12 +373,12 @@
[a1,a2] = deal( zeros(1,nCols) );
ind = find( lwrBnd == 0 );
a2(ind) = brk_pts( uprBnd(ind) + offsets(ind) );
a1(ind) = a2(ind) - 10;
a1(ind) = a2(ind) + 10;
ind2 = ind;

ind = find( uprBnd == size(brk_pts,1) + 1 );
ind = find( uprBnd == num_brk_pts + 1 );
a1(ind) = brk_pts( lwrBnd(ind) + offsets(ind) );
a2(ind) = a1(ind) + 10;
a2(ind) = a1(ind) - 10;

indOther = setdiff( 1:nCols, [ind2,ind] );
a1(indOther) = brk_pts( lwrBnd(indOther) + offsets(indOther) );
Expand Down

0 comments on commit 6043fa8

Please sign in to comment.