Skip to content

Commit

Permalink
Allow loading pretrained model only
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Aug 28, 2019
1 parent e161bbe commit 37b2eae
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions gpt_2_simple/gpt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,17 @@ def sample_batch():

def load_gpt2(sess,
run_name="run1",
checkpoint_dir="checkpoint"):
"""Loads the model checkpoint into a TensorFlow session
checkpoint_dir="checkpoint",
model_name=None,
model_dir='models'):
"""Loads the model checkpoint or existing model into a TensorFlow session
for repeated predictions.
"""

checkpoint_path = os.path.join(checkpoint_dir, run_name)
if model_name:
checkpoint_path = os.path.join(model_dir, model_name)
else:
checkpoint_path = os.path.join(checkpoint_dir, run_name)

hparams = model.default_hparams()
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
Expand All @@ -364,7 +369,10 @@ def load_gpt2(sess,
saver = tf.compat.v1.train.Saver(allow_empty=True)
sess.run(tf.compat.v1.global_variables_initializer())

print('Loading checkpoint', ckpt)
if model_name:
print('Loading pretrained model', ckpt)
else:
print('Loading checkpoint', ckpt)
saver.restore(sess, ckpt)


Expand Down

0 comments on commit 37b2eae

Please sign in to comment.