
Machine Learning Hierarchical Story Generation

Model Reference

Hierarchical Neural Story Generation https://arxiv.org/abs/1805.04833

Instructions

Install the required libraries

pip install -r requirements.txt

Note: this is only needed if you're running in a local environment, not in e.g. Google Colab.

Train the model

python train.py [args]

Note: Make sure you delete the data.npy and vocab.pkl files in the data/corpus directory before training!
If you're satisfied with the defaults listed below, no arguments need to be passed to the script.
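The note above can be automated. A minimal sketch of such a cleanup step follows; the file names come from the note, but the helper function itself is only an illustration, not part of the repository:

```python
import os

def clear_cached_corpus(data_dir="data/corpus"):
    """Delete the preprocessed cache files so the corpus is rebuilt on the next run.

    The file names match those mentioned in the note above; this helper
    is a hypothetical convenience, not a script shipped with the repo.
    """
    for name in ("data.npy", "vocab.pkl"):
        path = os.path.join(data_dir, name)
        if os.path.exists(path):
            os.remove(path)
```

Calling clear_cached_corpus() before each training run ensures stale vocabulary files never leak into a new experiment.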

Arguments for training and their defaults

  • --data_dir default=_'data/corpus'_
    Data directory containing input.txt
  • --input_encoding default=None
    Character encoding of input.txt ('utf_8' is recommended).
  • --log_dir default=_'logs'_
    Directory containing tensorboard logs.
  • --save_dir default=_'save'_
    Directory to store checkpointed models.
  • --rnn_size default=256
    Size of RNN hidden state.
  • --num_layers default=2
    Number of layers in the RNN.
  • --model default=_'lstm'_
    Type of model. Options: 'rnn', 'gru', or 'lstm'.
  • --batch_size default=50
    Minibatch size.
  • --seq_length default=25
    RNN sequence length.
  • --num_epochs default=50
    Number of training epochs.
  • --save_every default=1000
    Save frequency.
  • --grad_clip default=5.
    Clip gradients at this value.
  • --learning_rate default=0.002
    Learning rate.
  • --decay_rate default=0.97
    Learning-rate decay rate.
  • --gpu_mem default=0.666
    Fraction of GPU memory to be allocated to the training process. Default is 66.6%.
  • --init_from default=None
    Continue training from saved model at this path. Path must contain files saved by previous training process:
    'config.pkl' : configuration;
    'words_vocab.pkl' : vocabulary definitions;
    'checkpoint' : paths to model file(s) (created by tf).
    Note: this file contains absolute paths, be careful when moving files around;
    'model.ckpt-*' : file(s) with model definition (created by tf)
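The argument list above maps directly onto a command-line parser. The following is a sketch of how these flags and defaults might be declared with argparse; the repository's actual train.py defines its own parser, which may differ in detail:

```python
import argparse

def build_parser():
    # Sketch of train.py's command-line interface, reconstructed from the
    # defaults documented above. Not the repository's actual parser.
    p = argparse.ArgumentParser(description="Train the story-generation RNN")
    p.add_argument("--data_dir", default="data/corpus",
                   help="data directory containing input.txt")
    p.add_argument("--input_encoding", default=None,
                   help="character encoding of input.txt")
    p.add_argument("--log_dir", default="logs",
                   help="directory for tensorboard logs")
    p.add_argument("--save_dir", default="save",
                   help="directory for checkpointed models")
    p.add_argument("--rnn_size", type=int, default=256)
    p.add_argument("--num_layers", type=int, default=2)
    p.add_argument("--model", default="lstm", choices=["rnn", "gru", "lstm"])
    p.add_argument("--batch_size", type=int, default=50)
    p.add_argument("--seq_length", type=int, default=25)
    p.add_argument("--num_epochs", type=int, default=50)
    p.add_argument("--save_every", type=int, default=1000)
    p.add_argument("--grad_clip", type=float, default=5.0)
    p.add_argument("--learning_rate", type=float, default=0.002)
    p.add_argument("--decay_rate", type=float, default=0.97)
    p.add_argument("--gpu_mem", type=float, default=0.666)
    p.add_argument("--init_from", default=None)
    return p
```

Running build_parser().parse_args([]) yields a namespace populated with the documented defaults, which is a quick way to double-check them.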

Generate a story

To generate a multi-paragraph story, run:

python multipara.py --prime "Your seed text." -n 200

Parameters in multipara.py

  • -n default=100
    Number of words to sample.
  • --prime default=_'There has never been a better time to be alive, thought Watson.'_
    The prime text, i.e. the seed text.
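The interaction of these two parameters can be illustrated with a toy sampler: the output starts from the prime text and then appends n sampled words. This is only an illustration of the interface; the real script samples each word from the trained RNN's next-word distribution, not uniformly from a word list:

```python
import random

def sample_story(prime, vocab, n, seed=0):
    # Toy illustration of --prime and -n: begin with the seed text,
    # then append n words drawn from a vocabulary. A trained model
    # would instead sample from its learned next-word distribution.
    rng = random.Random(seed)
    words = prime.split()
    for _ in range(n):
        words.append(rng.choice(vocab))
    return " ".join(words)
```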

Note

If you're using Google Colab, prefix Python commands with an exclamation mark:

!python train.py
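If a script needs to behave differently in the two environments, a common idiom is to check whether the google.colab module has been preloaded, which it is inside Colab and is not elsewhere:

```python
import sys

# google.colab is preloaded in Colab sessions and absent in local
# environments, so its presence in sys.modules indicates Colab.
IN_COLAB = "google.colab" in sys.modules
```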

Additional Information