Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dicg callback info #539

Merged
merged 2 commits into from
Dec 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions src/dicg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function decomposition_invariant_conditional_gradient(
end

x = x0

if memory_mode isa InplaceEmphasis && !isa(x, Union{Array,SparseArrays.AbstractSparseArray})
# if integer, convert element type to most appropriate float
if eltype(x) <: Integer
Expand Down Expand Up @@ -127,12 +127,12 @@ function decomposition_invariant_conditional_gradient(
gamma = one(phi)

if lazy
if extra_vertex_storage === nothing
v = compute_extreme_point(lmo, gradient, lazy = lazy)
pre_computed_set = [v]
else
pre_computed_set = extra_vertex_storage
end
if extra_vertex_storage === nothing
v = compute_extreme_point(lmo, gradient, lazy=lazy)
pre_computed_set = [v]
else
pre_computed_set = extra_vertex_storage
end
end

if linesearch_workspace === nothing
Expand Down Expand Up @@ -168,16 +168,8 @@ function decomposition_invariant_conditional_gradient(
end

if lazy
d, v, v_index, a, away_index, phi, step_type =
lazy_dicg_step(
x,
gradient,
lmo,
pre_computed_set,
phi,
epsilon,
d;
)
d, v, v_index, a, away_index, phi, step_type =
lazy_dicg_step(x, gradient, lmo, pre_computed_set, phi, epsilon, d;)
else # non-lazy, call the simple and modified
v = compute_extreme_point(lmo, gradient, lazy=lazy)
dual_gap = fast_dot(gradient, x) - fast_dot(gradient, v)
Expand Down Expand Up @@ -205,7 +197,7 @@ function decomposition_invariant_conditional_gradient(
push!(pre_computed_set, v)
end
end

if callback !== nothing
state = CallbackState(
t,
Expand All @@ -223,7 +215,7 @@ function decomposition_invariant_conditional_gradient(
gradient,
step_type,
)
if callback(state) === false
if callback(state, a, v) === false
break
end
end
Expand Down Expand Up @@ -259,7 +251,7 @@ function decomposition_invariant_conditional_gradient(
gradient,
step_type,
)
callback(state)
callback(state, nothing, v)
end
end
return (x=x, v=v, primal=primal, dual_gap=dual_gap, traj_data=traj_data)
Expand Down Expand Up @@ -441,7 +433,7 @@ function blended_decomposition_invariant_conditional_gradient(
gradient,
step_type,
)
if callback(state) === false
if callback(state, a, v) === false
break
end
end
Expand Down Expand Up @@ -477,7 +469,7 @@ function blended_decomposition_invariant_conditional_gradient(
gradient,
step_type,
)
callback(state)
callback(state, nothing, v)
end
end
return (x=x, v=v, primal=primal, dual_gap=dual_gap, traj_data=traj_data)
Expand Down Expand Up @@ -518,19 +510,19 @@ function lazy_dicg_step(
v = compute_extreme_point(lmo, gradient)
grad_dot_v = fast_dot(gradient, v)
# Do lazy inface_point
if grad_dot_a_local - grad_dot_v >= phi / lazy_tolerance &&
grad_dot_a_local - grad_dot_v >= epsilon
if grad_dot_a_local - grad_dot_v >= phi / lazy_tolerance &&
grad_dot_a_local - grad_dot_v >= epsilon
step_type = ST_LAZY
a = a_local
away_index = a_local_loc
else
a = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
end

# Real dual gap promises enough progress.
grad_dot_fw_vertex = fast_dot(v, gradient)
dual_gap = grad_dot_x - grad_dot_fw_vertex

if dual_gap >= phi / lazy_tolerance
d = muladd_memory_mode(memory_mode, d, a, v)
#Lower our expectation for progress.
Expand Down
Loading