文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 405ac434 by Zhou Enhua

update

parent 134817d3
...@@ -116,3 +116,4 @@ def get_encoder(model_name): ...@@ -116,3 +116,4 @@ def get_encoder(model_name):
encoder=encoder, encoder=encoder,
bpe_merges=bpe_merges, bpe_merges=bpe_merges,
) )
...@@ -41,7 +41,7 @@ def sample_model( ...@@ -41,7 +41,7 @@ def sample_model(
""" """
enc = encoder.get_encoder(model_name) enc = encoder.get_encoder(model_name)
hparams = model.default_hparams() hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f: with open(os.path.join('../models', model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f)) hparams.override_from_dict(json.load(f))
if length is None: if length is None:
...@@ -61,7 +61,7 @@ def sample_model( ...@@ -61,7 +61,7 @@ def sample_model(
)[:, 1:] )[:, 1:]
saver = tf.train.Saver() saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) ckpt = tf.train.latest_checkpoint(os.path.join('../models', model_name))
saver.restore(sess, ckpt) saver.restore(sess, ckpt)
generated = 0 generated = 0
...@@ -71,7 +71,7 @@ def sample_model( ...@@ -71,7 +71,7 @@ def sample_model(
generated += batch_size generated += batch_size
text = enc.decode(out[i]) text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
path1 = "/home/stu/pkq/gpt-2/samples/new_try_samples/uncon" path1 = '../samples/story_unconditional'
path2 = str(generated) + ".txt" path2 = str(generated) + ".txt"
path = os.path.join(path1, path2) path = os.path.join(path1, path2)
...@@ -82,4 +82,4 @@ def sample_model( ...@@ -82,4 +82,4 @@ def sample_model(
return 1 return 1
if __name__ == '__main__': if __name__ == '__main__':
fire.Fire(sample_model) sample_model(model_name="345MShort", nsamples=3)
...@@ -18,5 +18,5 @@ def run1(path1): ...@@ -18,5 +18,5 @@ def run1(path1):
#run1("/home/stu/pkq/gpt-2/samples/story_start/1_start.txt") #run1("/home/stu/pkq/gpt-2/samples/story_start/1_start.txt")
raw_text = "" raw_text = " "
model2.train(raw_text, 5) model2.train(raw_text, 5)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment