-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathista.py
215 lines (174 loc) · 8.3 KB
/
ista.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
__author__ = 'Elad Sofer <[email protected]>'
import copy
import numpy as np
import torch.nn as nn
import torch
import seaborn as sns
from utills import generate_signal, plot_conv_rec_graph, BIM, plot_3d_surface, \
plot_2d_surface, plot_1d_surface, plot_norm_graph, plot_observations
from utills import sig_amount, r_step, eps_min, eps_max, loss3d_res_steps
from visualize_model import LandscapeWrapper
from utills import m, H
sns.set()  # apply seaborn's default plotting theme to all figures produced below
np.random.seed(0)  # seed NumPy's global RNG (presumably consumed by generate_signal — confirm)

# ---- ISTA configuration (defaults for ISTA.create_ISTA) ----
step_size = 0.1        # gradient-descent step size (mu)
max_iter = 10_000      # iteration cap for a single reconstruction
rho = 1e-2             # weight of the L1-norm penalty
eps_threshold = 1e-3   # stop once ||s_k - s_{k-1}||_1 <= eps_threshold
# ISTA
class ISTA(nn.Module, LandscapeWrapper):
    """
    Iterative Shrinkage-Thresholding Algorithm (ISTA) for sparse signal recovery.

    Minimizes ``0.5 * ||H s - x||_2^2 + rho * ||s||_1`` with the proximal-gradient
    update ``s <- shrinkage(s - mu * H^T (H s - x), mu * rho)``, starting from the
    all-zeros estimate and stopping once the L1 change between consecutive
    iterates is at most ``eps`` or after ``max_iter`` iterations.

    Args:
        H (torch.Tensor): Sensing matrix; the estimate has ``H.shape[1]`` rows.
        mu (float): Step size of the gradient-descent step.
        rho (float): Weight of the L1-norm penalty.
        max_iter (int): Maximum number of iterations.
        eps (float): Convergence threshold on ``||s_k - s_{k-1}||_1``.

    Attributes:
        H (torch.Tensor): Sensing matrix.
        rho (float): Weight of the L1-norm penalty.
        mu (float): Step size of the gradient-descent step.
        max_iter (int): Maximum number of iterations.
        eps (float): Convergence threshold.
        s (torch.Tensor or None): Latest sparse-signal estimate; ``None`` until
            :meth:`forward` runs.
        model_params (nn.Parameter or None): Frozen copy of ``s`` used by the
            visualization helpers; ``None`` until
            :meth:`set_model_visualization_params` is called.
    """

    def __init__(self, H, mu, rho, max_iter, eps):
        super(ISTA, self).__init__()
        # Objective parameters
        self.H = H
        self.rho = rho
        # Solver parameters
        self.mu = mu
        self.max_iter = max_iter
        self.eps = eps
        # Populated later by forward() / set_model_visualization_params().
        self.s = None
        self.model_params = None

    @staticmethod
    def shrinkage(x, beta):
        """
        Apply the soft-thresholding operator ``sign(x) * max(|x| - beta, 0)``.

        :param x: Input tensor.
        :param beta: Non-negative threshold.
        :return: Tensor with every entry of ``x`` shrunk towards 0 by ``beta``
            (entries with ``|x| <= beta`` become 0); same shape as ``x`` for a
            scalar ``beta``.
        """
        # zeros_like matches x's shape, dtype and device, so the operator does
        # not depend on the module-level dimension ``m``.
        return torch.mul(torch.sign(x), torch.max(torch.abs(x) - beta, torch.zeros_like(x)))

    def forward(self, x):
        """
        Reconstruct the sparse signal behind observation ``x`` with ISTA.

        :param x: Observation tensor (torch.Tensor), compatible with
            ``H @ s - x`` where ``s`` has shape ``(H.shape[1], 1)``.
        :return: Tuple ``(s, recovery_errors)``: the final estimate (also stored
            in ``self.s``) and a list of objective values (floats, see
            :meth:`loss_func`), one per iteration that did NOT satisfy the
            convergence test — the converging iteration itself is not recorded.
        """
        # Start from zeros sized by this instance's own sensing matrix (not the
        # module-level H), with H's dtype/device so matmul never mixes dtypes.
        self.s = self.H.new_zeros((self.H.shape[1], 1))
        threshold = self.mu * self.rho  # loop-invariant shrinkage threshold
        recovery_errors = []
        for _ in range(self.max_iter):
            s_prev = self.s
            # proximal gradient step: gradient of 0.5*||Hs - x||^2, then shrink
            residual = torch.matmul(self.H, s_prev) - x
            g_grad = s_prev - torch.mul(self.mu, torch.matmul(self.H.T, residual))
            self.s = self.shrinkage(g_grad, threshold)
            # cease if convergence achieved (L1 change between iterates)
            if torch.sum(torch.abs(self.s - s_prev)).item() <= self.eps:
                break
            # save recovery error
            recovery_errors.append(self.loss_func(self.s, x))
        return self.s, recovery_errors

    def set_model_visualization_params(self):
        """
        Freeze the current estimate ``self.s`` into ``self.model_params``.

        The visualize_model helpers operate on ``model_params``; the copy is
        detached and does not require gradients.

        :raises RuntimeError: if :meth:`forward` has not been run yet.
        """
        if self.s is None:
            raise RuntimeError(
                "set_model_visualization_params() called before forward(); no estimate to freeze"
            )
        self.model_params = nn.Parameter(self.s.detach(), requires_grad=False)

    def loss_func(self, s, x_sig):
        """
        Evaluate the ISTA objective ``0.5 * ||H s - x_sig||_2^2 + rho * ||s||_1``.

        :param s: Estimated sparse signal.
        :param x_sig: Observation signal x = Hs + w, where w is Gaussian noise.
        :return: Loss value as a Python float (detached from the autograd graph).
        """
        return 0.5 * torch.sum((torch.matmul(self.H, s) - x_sig) ** 2).item() + self.rho * s.norm(p=1).item()

    @staticmethod
    def copy(other):
        """
        Create a deep copy of ``other``.

        Args: other (ISTA): ISTA object to copy.
        Returns: ISTA: Independent deep copy of ``other``.
        """
        # Inside a method body ``copy`` resolves to the module, not this method.
        return copy.deepcopy(other)

    @classmethod
    def create_ISTA(cls, H=H, step_size=step_size, rho=rho, max_iter=max_iter, eps_threshold=eps_threshold):
        """
        Build an ISTA instance; defaults are the module-level configuration
        values captured when the class is defined.

        :param H: Sensing matrix.
        :param step_size: Step size of the gradient-descent step (``mu``).
        :param rho: Weight of the L1-norm penalty.
        :param max_iter: Maximum number of iterations.
        :param eps_threshold: Convergence threshold.
        :return: ISTA object.
        """
        return cls(H, step_size, rho, max_iter, eps_threshold)
def execute():
    """
    Run the ISTA adversarial-robustness experiment and write its plots.

    Steps:
    1. Generate ``sig_amount`` signals with ``generate_signal()``; each is an
       ``(x, s)`` pair with observation x = Hs + w (w Gaussian noise).
    2. Perform ISTA reconstruction on each x to obtain s^*.
    3. For every radius in ``linspace(eps_min, eps_max, r_step)``, run the BIM
       adversarial attack to obtain x_adv.
    4. Perform ISTA reconstruction on each x_adv to obtain s^*_adv.
    5. Record ||s^* - s^*_adv||_2 per signal and radius; plot its mean over signals.
    6. For the last signal and the last radius, plot observations, convergence
       curves and the 1D / 2D / 3D loss surfaces.

    :raises ValueError: if ``sig_amount`` or ``r_step`` is smaller than 1 —
        step 6 reuses the last signal/radius, so both loops must run at least once.
    """
    if sig_amount < 1 or r_step < 1:
        raise ValueError(
            "execute() needs sig_amount >= 1 and r_step >= 1 "
            "(got sig_amount={0}, r_step={1})".format(sig_amount, r_step)
        )

    dist_total = np.zeros((sig_amount, r_step))
    radius_vec = np.linspace(eps_min, eps_max, r_step)
    signals = [generate_signal() for _ in range(sig_amount)]

    ##########################################################
    for sig_idx, (x_original, s_original) in enumerate(signals):
        # ISTA without an attack reconstruction
        ISTA_t_model = ISTA.create_ISTA()
        s_gt, err_gt = ISTA_t_model(x_original.detach())
        # NOTE: len(err_gt) excludes the converging iteration (forward() breaks
        # before recording it).
        print("#### ISTA signal {0} convergence: iterations: {1} ####".format(sig_idx, len(err_gt)))
        s_gt = s_gt.detach()
        for e_idx, attack_eps in enumerate(radius_vec):
            # Fresh model per radius: BIM attacks it, then it reconstructs the
            # adversarial observation; its final state feeds the plots below.
            ISTA_adv_model = ISTA.create_ISTA()
            adv_x, _ = BIM(ISTA_adv_model, x_original, s_original, eps=attack_eps)
            adv_x = adv_x.detach()
            s_attacked, err_attacked = ISTA_adv_model(adv_x)
            dist_total[sig_idx, e_idx] = (s_gt - s_attacked).norm(2).item()
    ##########################################################
    # Optional persistence (uncomment to save):
    # np.save('data/stack/version1/matrices/ISTA_total_norm.npy', dist_total)
    plot_norm_graph(radius_vec, dist_total.mean(axis=0), fname='ISTA_norm2.pdf')

    # Everything below uses the last signal and the last attack radius.
    x = x_original.detach()
    plot_observations(adv_x, x, fname="ISTA_observation.pdf")
    plot_conv_rec_graph(s_attacked.detach().numpy(), s_gt.detach().numpy(), s_original,
                        err_attacked, err_gt, fname='ISTA_convergence.pdf')

    # Presenting last iteration signal loss surfaces for the last radius
    ISTA_adv_model.set_model_visualization_params()
    ISTA_t_model.set_model_visualization_params()

    # Extract loss surface
    dir_one, dir_two = ISTA_t_model.get_grid_vectors(ISTA_t_model, ISTA_adv_model)
    gt_line = ISTA_t_model.linear_interpolation(model_start=ISTA_t_model, model_end=ISTA_adv_model, x_sig=x,
                                                deepcopy_model=True)
    adv_line = ISTA_t_model.linear_interpolation(model_start=ISTA_t_model, model_end=ISTA_adv_model, x_sig=adv_x,
                                                 deepcopy_model=True)
    # Plotting 1D
    plot_1d_surface(gt_line, adv_line, 'ISTA_1D_LOSS.pdf')

    Z_gt, Z_adv = ISTA_t_model.random_plane(gt_model=ISTA_t_model, adv_model=ISTA_adv_model,
                                            adv_x=adv_x, x=x,
                                            dir_one=dir_one, dir_two=dir_two,
                                            steps=loss3d_res_steps)
    # Optional persistence (uncomment to save):
    # np.save('data/stack/version1/matrices/ISTA_Z_adv.npy', Z_adv)
    # np.save('data/stack/version1/matrices/ISTA_Z_gt.npy', Z_gt)

    # Plotting 2D
    plot_2d_surface(Z_gt, Z_adv, 'ISTA_2D_LOSS.pdf')

    # Plotting 3D - https://jakevdp.github.io/PythonDataScienceHandbook/04.12-three-dimensional-plotting.html
    plot_3d_surface(z_adv=Z_adv, z_gt=Z_gt, steps=loss3d_res_steps, fname="ISTA_COMBINED_3D_LOSS.pdf")
if __name__ == "__main__":
    # Script entry point: run the full experiment only when executed directly.
    execute()