-
Notifications
You must be signed in to change notification settings - Fork 16
/
DASTC.py
executable file
·251 lines (204 loc) · 10.4 KB
/
DASTC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import tensorflow as tf
from network.nn import *
from network.ControlGen import Model as BaseModel
class Model(BaseModel):
    """Dual-domain style-transfer model (TensorFlow 1.x graph mode).

    Extends the base ControlGen model with two data streams:
      * ``td_*`` (target domain): full adversarial style transfer —
        reconstruction loss + generator/discriminator losses.
      * ``sd_*`` (source domain): reconstruction (language-model) loss only.

    Graph assembly: ``build_placeholder`` declares the feeds,
    ``build_model`` builds both sub-graphs and the three optimizers
    (``optimize_rec`` for pretraining, ``optimize_tot`` for the full loss,
    ``optimize_d`` for the discriminator).
    """

    def __init__(self, args, vocab):
        # All real setup (vocab, hyper-parameters, graph build) happens in
        # the base class; this subclass only overrides graph pieces.
        super().__init__(args, vocab)

    def build_placeholder(self):
        """Declare every tf.placeholder feed for both domains.

        Prefix convention: ``td_`` = target-domain batch, ``sd_`` =
        source-domain batch. Shapes [None, None] are batch_size x seq_len.
        """
        # Dropout keep-probability (set to 1.0 at eval time — see
        # _make_feed_dict, where self.dropout_rate is used for training).
        self.dropout = tf.placeholder(tf.float32,
                                      name='dropout')
        # target_data
        self.td_batch_len = tf.placeholder(tf.int32,
                                           name='td_batch_len')
        self.td_enc_inputs = tf.placeholder(tf.int32, [None, None], #size * len
                                            name='td_enc_inputs')
        self.td_dec_inputs = tf.placeholder(tf.int32, [None, None],
                                            name='td_dec_inputs')
        self.td_targets = tf.placeholder(tf.int32, [None, None],
                                         name='td_targets')
        self.td_dec_mask = tf.placeholder(tf.float32, [None, None],
                                          name='td_dec_mask')
        self.td_labels = tf.placeholder(tf.float32, [None],
                                        name='td_labels')
        # NOTE(review): sequence lengths fed as float32 (not int32) —
        # presumably because tf.sequence_mask accepts floats; confirm.
        self.td_enc_lens = tf.placeholder(tf.float32, [None],
                                          name='td_enc_lens')
        # source_data
        self.sd_batch_len = tf.placeholder(tf.int32,
                                           name='sd_batch_len')
        self.sd_enc_inputs = tf.placeholder(tf.int32, [None, None], #size * len
                                            name='sd_enc_inputs')
        self.sd_dec_inputs = tf.placeholder(tf.int32, [None, None],
                                            name='sd_dec_inputs')
        self.sd_targets = tf.placeholder(tf.int32, [None, None],
                                         name='sd_targets')
        self.sd_dec_mask = tf.placeholder(tf.float32, [None, None],
                                          name='sd_dec_mask')
        self.sd_labels = tf.placeholder(tf.float32, [None],
                                        name='sd_labels')
        self.sd_enc_lens = tf.placeholder(tf.float32, [None],
                                          name='sd_enc_lens')

    def build_model(self, args):
        """Assemble both domain sub-graphs, combined losses and optimizers.

        Target domain yields [loss_rec, loss_d, loss_g, tsf_ids, rec_ids];
        source domain yields only its reconstruction loss.
        """
        outputs = self.style_transfer_model(args, self.td_enc_inputs, self.td_dec_inputs,
            self.td_targets, self.td_dec_mask, self.td_labels, self.td_enc_lens,
            scope = 'target')
        self.td_loss_rec, self.td_loss_d, self.td_loss_g, self.td_tsf_ids, self.td_rec_ids = outputs
        # Source scope returns a single tensor (reconstruction loss only) —
        # see the early return in style_transfer_model.
        outputs = self.style_transfer_model(args, self.sd_enc_inputs, self.sd_dec_inputs,
            self.sd_targets, self.sd_dec_mask, self.sd_labels, self.sd_enc_lens,
            scope = 'source')
        self.sd_loss_rec = outputs
        # optimizer
        self.loss_rec = self.td_loss_rec + self.sd_loss_rec
        self.loss_g = self.td_loss_g
        self.loss_d = self.td_loss_d
        # rho balances the adversarial (generator) term against reconstruction;
        # presumably set in BaseModel from args — confirm.
        self.loss = self.loss_rec + self.rho * self.loss_g
        # retrive_var (sic, from network.nn) — presumably collects trainable
        # variables under the named variable scopes; confirm against network/nn.
        theta_eg = retrive_var(['encoder_decoder'])
        theta_d = retrive_var(['discriminator'])
        # beta1=0.5 is the common Adam setting for adversarial training.
        opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
        # Clip encoder/decoder gradients by global norm for the total loss;
        # optimize_rec / optimize_d are left unclipped.
        grad, _ = zip(*opt.compute_gradients(self.loss, theta_eg))
        grad, _ = tf.clip_by_global_norm(grad, 30.0)
        self.optimize_tot = opt.apply_gradients(zip(grad, theta_eg))
        self.optimize_rec = opt.minimize(self.loss_rec, var_list=theta_eg)
        self.optimize_d = opt.minimize(self.loss_d, var_list=theta_d)

    def style_transfer_model(self, args, enc_input_ids,
            dec_input_ids, targets, dec_mask, labels, enc_lens, scope = None):
        """Build one domain's sub-graph.

        For scope == 'source': encoder/decoder reconstruction only; returns
        the scalar reconstruction loss. Otherwise (target): reconstruction +
        transfer decoding + per-domain discriminator; returns
        [loss_rec, loss_d, loss_g, tsf_ids, rec_ids].
        """
        # auto-encoder
        # AUTO_REUSE lets the two domain calls share encoder/decoder weights.
        with tf.variable_scope('encoder_decoder', reuse=tf.AUTO_REUSE):
            embedding = tf.get_variable('embedding', initializer=self.word_init)
            enc_inputs = tf.nn.embedding_lookup(embedding, enc_input_ids)
            dec_inputs = tf.nn.embedding_lookup(embedding, dec_input_ids)
            with tf.variable_scope('projection'):
                # style information
                # Output projection from decoder hidden state to vocab logits.
                projection = {}
                projection['W'] = tf.get_variable('W', [self.dim_h, self.vocab_size])
                projection['b'] = tf.get_variable('b', [self.vocab_size])
            encoder = self.create_cell(self.dim_h, args.n_layers, self.dropout, 'encoder')
            decoder = self.create_cell(self.dim_h, args.n_layers, self.dropout, 'decoder')
            if scope == 'source':
                # Source domain: plain language-model reconstruction, no
                # adversarial branch. Early return with a single tensor.
                loss_rec = self.reconstruction_lang(
                    encoder, enc_inputs, labels,
                    decoder, dec_inputs, targets, dec_mask, projection)
                return loss_rec
            loss_rec, origin_info, transfer_info = self.reconstruction(
                encoder, enc_inputs, labels,
                decoder, dec_inputs, targets, dec_mask, projection)
            # soft_tsf_ids: soft (distribution-valued) transferred ids used to
            # keep the generator differentiable for the discriminator.
            _, soft_tsf_ids, rec_ids, tsf_ids = self.run_decoder(
                decoder, dec_inputs, embedding, projection, origin_info, transfer_info)
        # discriminator
        with tf.variable_scope("discriminator"):
            with tf.variable_scope(scope):
                # Separate embedding for the classifier so discriminator
                # gradients do not flow into the generator's embedding table.
                classifier_embedding = tf.get_variable('embedding', initializer=self.word_init)
                # remove bos, use dec_inputs to avoid noises adding into enc_inputs
                real_sents = tf.nn.embedding_lookup(classifier_embedding, dec_input_ids[:, 1:])
                # tensordot contracts the soft id distribution against the
                # embedding table -> expected embedding per position.
                fake_sents = tf.tensordot(soft_tsf_ids, classifier_embedding, [[2], [0]])
                fake_sents = fake_sents[:, :-1, :] # make the dimension the same as real sents
                # mask the sequences
                mask = tf.sequence_mask(enc_lens, self.max_len - 1, dtype = tf.float32)
                mask = tf.expand_dims(mask, -1)
                real_sents *= mask
                fake_sents *= mask
                loss_d, loss_g = self.run_discriminator(real_sents, fake_sents, labels, args)
        return [loss_rec, loss_d, loss_g, tsf_ids, rec_ids]

    def reconstruction_lang(self, encoder, enc_inputs, labels,
            decoder, dec_inputs, targets, dec_mask, projection):
        """Masked cross-entropy reconstruction loss for the source domain.

        Encodes the input, swaps the first dim_y units of the latent state
        for a learned 'rec_style' vector, decodes, and averages the masked
        token-level cross-entropy over the batch.
        """
        labels = tf.reshape(labels, [-1, 1])
        _, latent_vector = tf.nn.dynamic_rnn(encoder, enc_inputs,
            scope='encoder', dtype=tf.float32)
        # construct new latent according to content and style vectors
        # Drop the style portion (first dim_y units), keep content.
        latent_vector = latent_vector[:, self.dim_y:]
        batch_size = tf.shape(labels)[0]
        # Learned reconstruction-style embedding, tiled across the batch.
        rec_style = tf.get_variable('rec_style', [self.dim_y], dtype=tf.float32)
        rec_style = tf.tile(tf.expand_dims(rec_style, 0), [batch_size, 1])
        latent_vector = tf.concat([rec_style, latent_vector], 1)
        hiddens, _ = tf.nn.dynamic_rnn(decoder, dec_inputs,
            initial_state=latent_vector, scope='decoder')
        hiddens = tf.nn.dropout(hiddens, self.dropout)
        hiddens = tf.reshape(hiddens, [-1, self.dim_h])
        logits = tf.matmul(hiddens, projection['W']) + projection['b']
        rec_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(targets, [-1]), logits=logits)
        # Zero out loss on padding positions.
        rec_loss *= tf.reshape(dec_mask, [-1])
        batch_size = tf.shape(labels)[0]
        # Sum over tokens, normalized by batch size (not token count).
        rec_loss = tf.reduce_sum(rec_loss) / tf.to_float(batch_size)
        return rec_loss

    def _make_feed_dict(self, td_batch, sd_batch, mode='train'):
        """Build the feed_dict for one step.

        Either batch may be None (that domain's placeholders are simply not
        fed — valid only when the fetched ops don't depend on them).
        mode='train' feeds self.dropout_rate; anything else feeds 1.0.
        """
        feed_dict = {}
        if mode == 'train':
            dropout = self.dropout_rate
        else:
            dropout = 1.0
        feed_dict[self.dropout] = dropout
        # target data
        if td_batch is not None:
            feed_dict[self.td_batch_len] = td_batch.batch_len
            feed_dict[self.td_enc_inputs] = td_batch.enc_batch
            feed_dict[self.td_dec_inputs] = td_batch.dec_batch
            feed_dict[self.td_labels] = td_batch.labels
            feed_dict[self.td_enc_lens] = td_batch.enc_lens
            feed_dict[self.td_targets] = td_batch.target_batch
            feed_dict[self.td_dec_mask] = td_batch.dec_padding_mask
        # source data
        if sd_batch is not None:
            feed_dict[self.sd_batch_len] = sd_batch.batch_len
            feed_dict[self.sd_enc_inputs] = sd_batch.enc_batch
            feed_dict[self.sd_dec_inputs] = sd_batch.dec_batch
            feed_dict[self.sd_labels] = sd_batch.labels
            feed_dict[self.sd_enc_lens] = sd_batch.enc_lens
            feed_dict[self.sd_targets] = sd_batch.target_batch
            feed_dict[self.sd_dec_mask] = sd_batch.dec_padding_mask
        return feed_dict

    def run_train_step(self, sess, td_batch, sd_batch, accumulator, epoch = None):
        """Runs one training iteration. Returns a dictionary containing train op,
        summaries, loss, global_step and (optionally) coverage loss.
        """
        # NOTE(review): epoch defaults to None, but is compared with `>`
        # below — calling without epoch raises TypeError on Python 3; confirm
        # all callers pass an int.
        feed_dict = self._make_feed_dict(td_batch, sd_batch)
        # NOTE(review): as written, the discriminator is trained only while
        # epoch <= pretrain_epochs, and skipped afterwards (in favor of
        # optimize_tot). That is the opposite of the usual
        # pretrain-then-adversarial schedule — confirm this is intended.
        if epoch > self.pretrain_epochs:
            results1 = {'td_loss_d': 0.0}
        else:
            to_return = {
                'td_loss_d': self.td_loss_d,
                'optimize_d': self.optimize_d,
            }
            results1 = sess.run(to_return, feed_dict)
        # After pretraining: full loss (rec + rho*g, gradient-clipped);
        # during pretraining: reconstruction loss only.
        if epoch > self.pretrain_epochs:
            optimize = self.optimize_tot
        else:
            optimize = self.optimize_rec
        to_return = {
            'sd_loss_rec': self.sd_loss_rec,
            'td_loss_rec': self.td_loss_rec,
            'td_loss_g': self.td_loss_g,
            'optimize': optimize,
        }
        results2 = sess.run(to_return, feed_dict)
        results = {**results1, **results2}
        # Accumulator decides which named results it tracks.
        accumulator.add([results[name] for name in accumulator.names])

    def run_eval_step(self, sess, batch, domain=None):
        """Run one evaluation step on a single-domain batch.

        Returns the fetched losses plus reconstruction/transfer ids.
        Raises ValueError for any domain other than 'source'/'target'.
        """
        if domain == 'source':
            # NOTE(review): self.sd_rec_ids / self.sd_tsf_ids are never
            # assigned in this class (build_model stores only sd_loss_rec for
            # the source domain) — presumably defined in BaseModel, otherwise
            # this branch raises AttributeError; confirm.
            feed_dict = self._make_feed_dict(None, batch, mode = 'eval')
            to_return = {
                'rec_ids': self.sd_rec_ids,
                'tsf_ids': self.sd_tsf_ids,
                'sd_loss_rec': self.sd_loss_rec,
            }
        elif domain == 'target':
            feed_dict = self._make_feed_dict(batch, None, mode = 'eval')
            to_return = {
                'rec_ids': self.td_rec_ids,
                'tsf_ids': self.td_tsf_ids,
                'td_loss_rec': self.td_loss_rec,
                'td_loss_g': self.td_loss_g,
                'td_loss_d': self.td_loss_d,
            }
        else:
            raise ValueError('Wrong domain name: %s.' % domain)
        return sess.run(to_return, feed_dict)

    def get_output_names(self, domain=None):
        """Names of the scalar outputs tracked for the given domain.

        Matches the keys produced by run_train_step / run_eval_step; used by
        the accumulator. Raises ValueError for unknown domains.
        """
        if domain == 'source':
            return ['sd_loss_rec']
        elif domain == 'target':
            return ['td_loss_rec', 'td_loss_g', 'td_loss_d']
        elif domain == 'all':
            return ['td_loss_rec', 'td_loss_g', 'td_loss_d',
                    'sd_loss_rec']
        else:
            raise ValueError('Wrong domain name: %s.' % domain)