From 7cdac144c311ef97f125d8a5e12198b9f973d4cd Mon Sep 17 00:00:00 2001
From: Jeff Wu
Date: Thu, 14 Feb 2019 11:34:14 -0800
Subject: [PATCH] fix bug and remove f strings

---
 src/generate_unconditional_samples.py  |  4 ++--
 src/interactive_conditional_samples.py | 11 +++++++----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index 98efe8165..205a754f1 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -28,7 +28,7 @@ def sample_model(
     if length is None:
         length = hparams.n_ctx
     elif length > hparams.n_ctx:
-        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
 
     with tf.Session(graph=tf.Graph()) as sess:
         output = sample.sample_sequence(
@@ -49,7 +49,7 @@ def sample_model(
                 generated += batch_size
                 text = enc.decode(out[i])
                 print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
-                print(f"{text}")
+                print(text)
 
 if __name__ == '__main__':
     fire.Fire(sample_model)
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index 50381176f..def38dd2e 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -31,7 +31,7 @@ def interact_model(
     if length is None:
         length = hparams.n_ctx // 2
     elif length > hparams.n_ctx:
-        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
 
     with tf.Session(graph=tf.Graph()) as sess:
         context = tf.placeholder(tf.int32, [batch_size, None])
@@ -40,7 +40,7 @@ def interact_model(
             context=context,
             batch_size=batch_size,
             temperature=temperature, top_k=top_k
-        )[:, 1:]
+        )
 
         saver = tf.train.Saver()
         ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
@@ -48,17 +48,20 @@
 
         while True:
             raw_text = input("Model prompt >>> ")
+            while not raw_text:
+                print('Prompt should not be empty!')
+                raw_text = input("Model prompt >>> ")
             context_tokens = enc.encode(raw_text)
             generated = 0
             for _ in range(nsamples // batch_size):
                 out = sess.run(output, feed_dict={
                     context: [context_tokens for _ in range(batch_size)]
-                })
+                })[:, len(context_tokens):]
                 for i in range(batch_size):
                     generated += 1
                     text = enc.decode(out[i])
                     print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
-                    print(f"{text}")
+                    print(text)
             print("=" * 80)
 
 if __name__ == '__main__':
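
Reviewer's note (not part of the patch): the "fix bug" half of this change
swaps the graph-level slice ")[:, 1:]" in interactive_conditional_samples.py,
which stripped exactly one leading token from each sample, for
"[:, len(context_tokens):]" applied to the session output, so the echoed
prompt is removed in full before decoding; it also loops until the user
enters a non-empty prompt. A minimal sketch of the slicing difference,
assuming (as the new slice implies) that sample_sequence returns the context
tokens followed by the generated tokens -- the token IDs below are invented
for illustration:

    import numpy as np

    context_tokens = [464, 3290]         # hypothetical encoded prompt (2 tokens)
    generated_tokens = [318, 257, 1332]  # hypothetical continuation

    # One batch row as sess.run(output, ...) would return it: prompt + completion.
    out = np.array([context_tokens + generated_tokens])

    print(out[:, 1:])
    # [[3290  318  257 1332]]  <- old slice: prompt residue is still echoed

    print(out[:, len(context_tokens):])
    # [[ 318  257 1332]]       <- new slice: completion only

The f-string removals look like a separate compatibility cleanup: f-strings
require Python 3.6+, while the "%" formatting and plain print(text) used here
behave the same on earlier Python 3 versions.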