decoding_control.py · forked from Yu-Fangxu/COLD-Attack
365 lines (311 loc) · 16.5 KB
import time

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import *

stop_words = set(stopwords.words('english'))
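
# decode() implements the controllable decoding loop of COLD-Attack: it keeps
# a soft distribution over adversarial suffix tokens (y_logits), learns an
# additive perturbation (epsilon) by gradient descent on a weighted sum of
# differentiable losses, and injects Gaussian noise between updates, i.e. a
# Langevin-dynamics-style sampler in logit space.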
def decode(model, tokenizer, device, x="", z="", key_word="", constraints=None, args=None, sys_prompt=None, prefix=None, model_back=None, zz=None):
    '''
    x: left context (prompt in the lexical task)
    z: optimization target (original ending in the counterfactual task)
    constraints: constraint set in the lexically constrained task
    '''
    model.eval()

    # Build the natural-language control instruction appended to the prompt.
    if args.control_type == "sentiment":
        control = "Write the output in an extremely exciting way. "
    elif args.control_type == "lexical":
        control = "The output written MUST include the following keywords: "
        keywords = key_word.split()
        for k in keywords:
            control += k
            control += ", "
        control = control.strip().strip(",")
        control += ". "
        print(control)
    elif args.control_type == "style":
        control = "Write the output as a Twitter post. "
    elif args.control_type == "format":
        control = "Write the output in a JSON format. "
    # Refusal- and safety-related words that the optimized suffix should steer
    # the model away from emitting.
    words = ['sorry', 'apologetic', 'regretful',
             'unethical', 'immoral', 'dishonest',
             'illegal', 'unlawful', 'illicit',
             'harmful', 'detrimental', 'injurious',
             'afraid', 'fearful', 'apprehensive',
             'cannot', 'unable',
             'apologize', 'regret',
             'unable', 'incapable', 'powerless',
             'not', 'no', 'nay',
             'responsible', 'accountable', 'liable',
             'ethic', 'morality', 'moral',
             'legal', 'lawful', 'legitimate',
             'ethical', 'moral', 'principled',
             'fulfill', 'accomplish', 'achieve',
             'just', 'fair', 'equitable',
             'trustworthy', 'reliable', 'dependable',
             'repentant', 'penitent', 'remorseful',
             'unprincipled', 'unscrupulous', 'corrupt',
             'unlawful', 'prohibited', 'forbidden',
             'noxious', 'destructive', 'damaging',
             'timid', 'terrified', 'panicked',
             'unable', 'incompetent', 'incapacitated',
             'prohibited', 'banned', 'forbidden',
             'deny', 'refuse', 'decline',
             'answerable', 'culpable', 'answerable',
             'moral code', 'values', 'principles',
             'legitimate', 'valid', 'authorized',
             'virtuous', 'righteous', 'honorable',
             'satisfy', 'complete', 'execute',
             'fair-minded', 'impartial', 'equitable',
             'reliable', 'trustable', 'faithful', 'invalid', 'safe', 'not', "can't", "but", "against", "0", "1", "2",
             "3", "4", "5", "6", "7", "8", "9"]
    # Also penalize the upper-case variants of these (lower-case) words.
    uppercase_words = [word.upper() for word in words]
    bad_words = words + uppercase_words
    bad_words = ' '.join(bad_words)
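    # The joined bad-word string is tokenized again below, both to build a
    # vocabulary mask (x_mask / bad_mask) and to form the reference ids
    # (bad_words_t) for the BLEU-based rejection loss.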
    BIG_CONST = 1e10
    if args.use_sysprompt:
        x_sys = sys_prompt + x
        x_ = tokenizer.encode(x_sys)[1:]
    else:
        x_ = tokenizer.encode(x)[1:]
    x_t = torch.tensor(x_, device=device, dtype=torch.long)
    x_onehot = one_hot(x_t, dimension=tokenizer.vocab_size)

    # repeat batch_size times
    x_t = x_t.unsqueeze(0).repeat(args.batch_size, 1)
    x_onehot = x_onehot.repeat(args.batch_size, 1, 1)

    target_mask = None
    x_mask = None
    # Encode the optimization target z.
    target_ = tokenizer.encode(z)[1:]  # [1:] drops the BOS token added by the tokenizer
    target_t = torch.tensor(target_, device=device, dtype=torch.long)
    target_onehot = one_hot(target_t, dimension=tokenizer.vocab_size)
    target_onehot = target_onehot.repeat(args.batch_size, 1, 1)
    target_t = target_t.unsqueeze(0).repeat(args.batch_size, 1)

    length = args.length
    if length <= 0:
        # A non-positive length means: target length plus |length| extra tokens.
        length = target_t.shape[1] - length

    if args.verbose:
        print("x:\t|%s|\nz:\t|%s|\nlength:\t%d\ncontrol:\t%s" % (
            tokenizer.decode(x_), tokenizer.decode(target_), length, control))

    # target_mask: [batch_size, length, vocab_size]; marks the non-stopword
    # target tokens (plus the first word) that the output is encouraged to use.
    target_words = word_tokenize(z)
    target_nonstop_words = [w.lower() for w in target_words if w.lower() not in stop_words and w.isalnum()]
    target_nonstop_words += [target_words[0]]  # add the first token
    target_nonstop_words = ' ' + ' '.join(target_nonstop_words)
    target_nonstop_ = tokenizer.encode(target_nonstop_words)
    print('|' + target_nonstop_words + '|')

    target_mask = np.zeros([tokenizer.vocab_size])
    target_mask[target_nonstop_] = 1.
    target_mask = torch.tensor(target_mask, device=device)
    target_mask = target_mask.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, length, 1)
    # Encode the control instruction for the control-consistency loss.
    control_ = tokenizer.encode(control)[1:]  # [1:] drops the BOS token
    control_t = torch.tensor(control_, device=device, dtype=torch.long)
    control_onehot = one_hot(control_t, dimension=tokenizer.vocab_size)
    control_onehot = control_onehot.repeat(args.batch_size, 1, 1)
    control_t = control_t.unsqueeze(0).repeat(args.batch_size, 1)
    ###################################################
    x_ = tokenizer.encode(x)[1:]  # [1:] drops the BOS token
    x_t = torch.tensor(x_, device=device, dtype=torch.long)
    x_onehot = one_hot(x_t, dimension=tokenizer.vocab_size)
    x_onehot = x_onehot.repeat(args.batch_size, 1, 1)
    x_t = x_t.unsqueeze(0).repeat(args.batch_size, 1)

    length = args.length
    if length <= 0:
        length = x_t.shape[1] - length

    # x_mask marks the bad-word token ids; bad_mask is its complement
    # (1 everywhere except at bad-word ids).
    x_words = tokenizer.encode(bad_words)
    x_mask = np.zeros([tokenizer.vocab_size])
    x_mask[x_words] = 1.
    x_mask = torch.tensor(x_mask, device=device)
    bad_mask = x_mask.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, length, 1)
    bad_mask = torch.ones_like(bad_mask, device=device) - bad_mask

    bad_words_ = tokenizer.encode(bad_words)
    bad_words_t = torch.tensor(bad_words_, device=device, dtype=torch.long)
    bad_words_onehot = one_hot(bad_words_t, dimension=tokenizer.vocab_size)
    bad_words_onehot = bad_words_onehot.repeat(args.batch_size, 1, 1)
    bad_words_t = bad_words_t.unsqueeze(0).repeat(args.batch_size, 1)
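    # Note: bad_mask and bad_words_onehot are built here, but only bad_words_t
    # is consumed later (by the rejection loss); every call below passes
    # bad_mask=None.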
    ###################################################

    if args.init_mode == 'original':
        init_logits = initialize(model, x_t, length, args.init_temp, args.batch_size, device, tokenizer)
    else:
        # Initialize from the sharpened one-hot target distribution, truncated
        # or zero-padded to the requested length.
        init_logits = target_onehot / 0.01
        init_logits = init_logits[:, :length, :]
        if length > init_logits.shape[1]:
            init_logits = torch.cat(
                [init_logits,
                 torch.zeros([args.batch_size, length - init_logits.shape[1], tokenizer.vocab_size], device=device)],
                dim=1)
    text, _, _ = get_text_from_logits(init_logits, tokenizer)
    for bi in range(args.batch_size):
        print("[initial]: %s" % (text[bi]))
    if args.wandb:
        wandb.init(
            project=args.mode + str(int(round(time.time() * 1000))),
            config=args)

    assert args.prefix_length <= 0, "The current code does not support prefix-length > 0"

    y_logits = init_logits
    epsilon = torch.nn.Parameter(torch.zeros_like(y_logits))
    epsilon.requires_grad = True
    optim = torch.optim.Adam([epsilon], lr=args.stepsize)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=args.stepsize_iters,
                                                gamma=args.stepsize_ratio)

    frozen_len = args.frozen_length

    y_logits_ = None
    noise_std = 0.0

    ## Encode x beforehand
    soft_forward_x = x_onehot[:, -1:, :]  # only the last token of x is fed together with the soft y
    if x_t.shape[1] == 1:
        x_model_past = None
    else:
        # Cache the key/value states of the x prefix so each iteration only
        # needs to forward the soft y tokens.
        x_model_outputs = model(x_t[:, :-1], use_cache=True)
        x_model_past = x_model_outputs.past_key_values

    mask_t = None
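    # Optimization loop: each iteration forwards the perturbed soft logits
    # through the frozen model, evaluates the fluency / attack / control /
    # rejection losses, takes an Adam step on epsilon, and then perturbs
    # y_logits with Gaussian noise.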
    for iter in range(args.num_iters):
        optim.zero_grad()
        y_logits_ = y_logits + epsilon

        soft_forward_y = y_logits_ / 0.001
        if args.straight_through:
            if mask_t is None:
                # Straight-through estimator: forward the sharpened logits but
                # let gradients flow through y_logits_.
                soft_forward_y = (y_logits_.detach() / 0.001 - y_logits_).detach() + y_logits_
            else:
                soft_forward_y = top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=x_mask, bad_mask=None) / 0.001

        if args.fp16:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                y_logits_t = soft_forward(model, soft_forward_x, soft_forward_y, args.topk, extra_mask=x_mask, x_past=x_model_past, bad_mask=None)
        else:
            y_logits_t = soft_forward(model, soft_forward_x, soft_forward_y, args.topk, extra_mask=x_mask, x_past=x_model_past, bad_mask=None)

        if args.topk == 0:
            mask_t = None
        else:
            # Top-k mask over the model's next-token predictions, reused to
            # restrict the soft forward pass in the next iteration.
            _, indices_t = torch.topk(y_logits_t, args.topk)
            mask_t = torch.zeros_like(y_logits_t).scatter_(2, indices_t, 1)
        # Fluency loss: soft cross-entropy between the model's (top-k filtered)
        # predictions and the current soft output distribution.
        flu_loss = soft_nll(
            top_k_filter_3d(y_logits_t / args.output_lgt_temp, args.topk, extra_mask=x_mask, bad_mask=None),
            y_logits_ / args.input_lgt_temp)
        # Attack loss: cross-entropy of the target tokens given x + y + control.
        soft_forward_y_ = (y_logits_.detach() / 0.001 - y_logits_).detach() + y_logits_
        if args.fp16:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                xyz_target_logits, xyz_length = soft_forward_xyz_target(model, soft_forward_x, soft_forward_y_, control_onehot, target_onehot)
        else:
            xyz_target_logits, xyz_length = soft_forward_xyz_target(model, soft_forward_x, soft_forward_y_, control_onehot, target_onehot)

        # Reshape, then slice out the positions that predict the target tokens.
        bz = args.batch_size
        lg = xyz_target_logits.shape[1]
        st = xyz_length - 1
        ed = xyz_target_logits.shape[1] - 1
        xyz_target_logits = xyz_target_logits.view(-1, xyz_target_logits.shape[-1])
        target_logits = torch.cat([xyz_target_logits[bi * lg + st:bi * lg + ed, :] for bi in range(bz)], dim=0)

        c_loss_1 = torch.nn.CrossEntropyLoss(reduction='none')(
            target_logits,
            target_t.view(-1))
        c_loss_1 = c_loss_1.view(args.batch_size, -1).mean(-1)
        # Control-consistency loss: cross-entropy of the control tokens given x + y.
        if args.fp16:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_x, soft_forward_y_, control_onehot)
        else:
            xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_x, soft_forward_y_, control_onehot)

        # Reshape, then slice out the positions that predict the control tokens.
        bz = args.batch_size
        lg = xyz_logits.shape[1]
        st = xy_length - 1
        ed = xyz_logits.shape[1] - 1
        xyz_logits = xyz_logits.view(-1, xyz_logits.shape[-1])
        z_logits = torch.cat([xyz_logits[bi * lg + st:bi * lg + ed, :] for bi in range(bz)], dim=0)

        c_loss_2 = torch.nn.CrossEntropyLoss(reduction='none')(
            z_logits,
            control_t.view(-1))
        c_loss_2 = c_loss_2.view(args.batch_size, -1).mean(-1)
        # Rejection loss: unigram-BLEU overlap between the soft output and the
        # bad-word ids.
        c_loss_3 = batch_log_bleulosscnn_ae(
            decoder_outputs=y_logits_.transpose(0, 1),
            target_idx=bad_words_t,
            ngram_list=[1]
        )

        loss = args.goal_weight * c_loss_1 + 1.0 * flu_loss - args.rej_weight * c_loss_3 + 100 * c_loss_2
        loss = loss.mean()
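        # Energy weighting: the attack loss is scaled by args.goal_weight,
        # fluency by 1.0, and the control-consistency loss by a fixed 100; the
        # rejection term enters with a negative sign so that optimization
        # pushes the suffix away from the refusal vocabulary.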
        if iter < args.num_iters - 1:  # skip the update at the last iteration so mask_t does not change
            loss.backward()
            optim.step()
            scheduler.step()
            last_lr = scheduler.get_last_lr()[0]
        if args.verbose and ((iter + 1) % args.print_every == 0 or iter == 0 or iter + 1 == args.num_iters):
            text, _, last_text_ids = decode_with_model_topk(
                model, y_logits_, args.topk, soft_forward_x, x_model_past, tokenizer, extra_mask=None, bad_mask=None)
            text_post = text
            for bi in range(args.batch_size):
                prompt = x + " " + text_post[bi] + control
                input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
                print(input_ids.shape)
                print("\n Output of the model:\n")
                output_ids = model.generate(inputs=input_ids, temperature=0.7, max_length=512, do_sample=True, top_k=args.topk)
                print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
                print(
                    "%d, loss: %.4f, flu_loss: %.4f, attack_loss: %.4f, consistency_loss: %.4f, rejection_loss: %.4f, lr: %.4f, |%s|" % (
                        iter + 1, loss.item(), flu_loss[bi].item(), c_loss_1[bi].item(), c_loss_2[bi].item(), c_loss_3[bi].item(), last_lr, text_post[bi]))
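        # The generations above are for monitoring only; they do not feed back
        # into the optimization of epsilon.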
        ## noise
        if iter < args.num_iters - 1:
            large_noise_iters = [int(_) for _ in args.large_noise_iters.split(',')]
            large_gs_stds = [float(_) for _ in args.large_gs_std.split(',')]
            noise_std = 0.
            if iter % args.noise_iters == 0:
                noise_last = True
                for ni in range(len(large_noise_iters)):
                    if iter < large_noise_iters[ni]:
                        noise_last = False
                        break
                if noise_last:
                    noise_std = args.gs_std
                else:
                    noise_std = large_gs_stds[ni]

                noise = torch.normal(mean=args.gs_mean, std=noise_std, size=epsilon.size(),
                                     device='cuda', requires_grad=False)
                if args.win_anneal_iters >= 0 and iter >= args.win_anneal_iters:
                    zeros = torch.zeros_like(noise)
                    noise_mix = torch.cat([zeros[:, :frozen_len], noise[:, frozen_len:]], dim=1)
                    y_logits = y_logits + noise_mix
                else:
                    y_logits = y_logits + noise
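    # Gradient step plus Gaussian noise amounts to Langevin-dynamics-style
    # sampling over epsilon; args.large_noise_iters / args.large_gs_std
    # schedule larger noise in early iterations to encourage exploration.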
    # Final decoding with the optimized logits.
    text, _, last_text_ids = decode_with_model_topk(
        model, y_logits_, args.topk, soft_forward_x, x_model_past, tokenizer, extra_mask=None, bad_mask=None)
    text_post = text
    decoded_text = []
    for bi in range(args.batch_size):
        prompt = x + " " + text_post[bi] + control
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output_ids = model.generate(inputs=input_ids, temperature=0.7, max_length=512, pad_token_id=tokenizer.pad_token_id, do_sample=True, top_k=args.topk)
        output_ids = output_ids[:, input_ids.shape[1]:]  # keep only the generated continuation
        text_dec = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        decoded_text.append(text_dec.strip())

    # Perplexity of the decoded ids under the model itself.
    last_rank_loss = model(input_ids=last_text_ids, labels=last_text_ids).loss
    last_rank_loss = last_rank_loss.detach().clone().data.cpu().numpy()
    ppl_last = np.exp(last_rank_loss)

    prompt_with_adv = [x + " " + t for t in text_post]
    return ppl_last, text, text_post, decoded_text, prompt_with_adv
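
# Usage sketch (hypothetical values; in the repo, `args` comes from the CLI
# argument parser, and helpers such as one_hot / soft_forward / soft_nll /
# decode_with_model_topk come from util.py):
#
#   model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   ppl, text, text_post, decoded, prompts = decode(
#       model, tokenizer, device,
#       x="<prompt>", z="<target>", key_word="<keywords>", args=args)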