文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 405ac434 by Zhou Enhua

update

parent 134817d3
......@@ -116,3 +116,4 @@ def get_encoder(model_name):
encoder=encoder,
bpe_merges=bpe_merges,
)
......@@ -41,7 +41,7 @@ def sample_model(
"""
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
with open(os.path.join('../models', model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))
if length is None:
......@@ -61,7 +61,7 @@ def sample_model(
)[:, 1:]
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
ckpt = tf.train.latest_checkpoint(os.path.join('../models', model_name))
saver.restore(sess, ckpt)
generated = 0
......@@ -71,7 +71,7 @@ def sample_model(
generated += batch_size
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
path1 = "/home/stu/pkq/gpt-2/samples/new_try_samples/uncon"
path1 = '../samples/story_unconditional'
path2 = str(generated) + ".txt"
path = os.path.join(path1, path2)
......@@ -82,4 +82,4 @@ def sample_model(
return 1
if __name__ == '__main__':
fire.Fire(sample_model)
sample_model(model_name="345MShort", nsamples=3)
......@@ -18,5 +18,5 @@ def run1(path1):
#run1("/home/stu/pkq/gpt-2/samples/story_start/1_start.txt")
raw_text = ""
raw_text = " "
model2.train(raw_text, 5)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment