TV Script Generation

In this project, you'll generate your own Simpsons TV scripts using RNNs. You'll be using part of the Simpsons dataset of scripts from 27 seasons. The Neural Network you'll build will generate a new TV script for a scene at Moe's Tavern.

Get the Data

The data is already provided for you. You'll be using a subset of the original dataset, consisting only of the scenes in Moe's Tavern. This doesn't include other versions of the tavern, like "Moe's Cavern", "Flaming Moe's", "Uncle Moe's Family Feed-Bag", etc.

In [64]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
import helper

data_dir = './data/simpsons/moes_tavern_lines.txt'
text = helper.load_data(data_dir)
# Ignore notice, since we don't use it for analysing the data
text = text[81:]

Explore the Data

Play around with view_sentence_range to view different parts of the data.

In [14]:
view_sentence_range = (0, 10)

"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
import numpy as np

print('Dataset Stats')
print('Roughly the number of unique words: {}'.format(len({word: None for word in text.split()})))
scenes = text.split('\n\n')
print('Number of scenes: {}'.format(len(scenes)))
sentence_count_scene = [scene.count('\n') for scene in scenes]
print('Average number of sentences in each scene: {}'.format(np.average(sentence_count_scene)))

sentences = [sentence for scene in scenes for sentence in scene.split('\n')]
print('Number of lines: {}'.format(len(sentences)))
word_count_sentence = [len(sentence.split()) for sentence in sentences]
print('Average number of words in each line: {}'.format(np.average(word_count_sentence)))

print()
print('The sentences {} to {}:'.format(*view_sentence_range))
print('\n'.join(text.split('\n')[view_sentence_range[0]:view_sentence_range[1]]))
Dataset Stats
Roughly the number of unique words: 11492
Number of scenes: 262
Average number of sentences in each scene: 15.248091603053435
Number of lines: 4257
Average number of words in each line: 11.50434578341555

The sentences 0 to 10:
Moe_Szyslak: (INTO PHONE) Moe's Tavern. Where the elite meet to drink.
Bart_Simpson: Eh, yeah, hello, is Mike there? Last name, Rotch.
Moe_Szyslak: (INTO PHONE) Hold on, I'll check. (TO BARFLIES) Mike Rotch. Mike Rotch. Hey, has anybody seen Mike Rotch, lately?
Moe_Szyslak: (INTO PHONE) Listen you little puke. One of these days I'm gonna catch you, and I'm gonna carve my name on your back with an ice pick.
Moe_Szyslak: What's the matter Homer? You're not your normal effervescent self.
Homer_Simpson: I got my problems, Moe. Give me another one.
Moe_Szyslak: Homer, hey, you should not drink to forget your problems.
Barney_Gumble: Yeah, you should only drink to enhance your social skills.


Implement Preprocessing Functions

The first thing to do to any dataset is preprocessing. Implement the following preprocessing functions below:

  • Lookup Table
  • Tokenize Punctuation

Lookup Table

To create a word embedding, you first need to transform the words to ids. In this function, create two dictionaries:

  • A dictionary to go from a word to an id, which we'll call vocab_to_int
  • A dictionary to go from an id to a word, which we'll call int_to_vocab

Return these dictionaries as the following tuple (vocab_to_int, int_to_vocab)

In [65]:
import numpy as np
import problem_unittests as tests

def create_lookup_tables(text):
    """
    Create lookup tables for vocabulary
    :param text: The text of tv scripts split into words
    :return: A tuple of dicts (vocab_to_int, int_to_vocab)
    """
    vocab = set(text)
    
    vocab_to_int = {word: index for index, word in enumerate(vocab)}
    int_to_vocab = {index: word for (word, index) in vocab_to_int.items()}
    
    return vocab_to_int, int_to_vocab


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_create_lookup_tables(create_lookup_tables)
Tests Passed
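
As a quick sanity check (separate from the unit test above), the two tables can be used to round-trip a few words. A minimal illustrative sketch:

# Illustrative only: convert words to ids and back with the lookup tables.
sample_words = ['moe', 'hands', 'homer', 'a', 'beer']
v2i, i2v = create_lookup_tables(sample_words)
ids = [v2i[word] for word in sample_words]      # words -> ids
assert [i2v[i] for i in ids] == sample_words    # ids -> words round-trips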

Tokenize Punctuation

We'll be splitting the script into a word array using spaces as delimiters. However, punctuation like periods and exclamation marks makes it hard for the neural network to distinguish between the word "bye" and "bye!".

Implement the function token_lookup to return a dict that will be used to tokenize symbols like "!" into "||Exclamation_Mark||". Create a dictionary for the following symbols where the symbol is the key and the value is the token:

  • Period ( . )
  • Comma ( , )
  • Quotation Mark ( " )
  • Semicolon ( ; )
  • Exclamation mark ( ! )
  • Question mark ( ? )
  • Left Parenthesis ( ( )
  • Right Parenthesis ( ) )
  • Dash ( -- )
  • Return ( \n )

This dictionary will be used to tokenize the symbols and add a delimiter (space) around each one. This separates each symbol into its own word, making it easier for the neural network to predict the next word. Make sure you don't use a token that could be confused as a word. Instead of using the token "dash", try using something like "||dash||".

In [16]:
def token_lookup():
    """
    Generate a dict to turn punctuation into a token.
    :return: Tokenize dictionary where the key is the punctuation and the value is the token
    """
    
    return {
        '.': '||period||',
        ',': '||comma||',
        '"': '||quotation_mark||',
        ';': '||semicolon||',
        '!': '||exclamation_mark||',
        '?': '||question_mark||',
        '(': '||left_parentheses||',
        ')': '||right_parentheses||',
        '--': '||dash||',
        '\n': '||return||'
    }

"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_tokenize(token_lookup)
Tests Passed
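
The token dictionary itself is applied inside helper.preprocess_and_save_data, whose source isn't shown here. A minimal sketch of how such a replacement step might look (the helper's actual implementation may differ):

# Illustrative sketch: surround every symbol's token with spaces, then split on whitespace.
def tokenize_text(text, token_dict):
    for symbol, token in token_dict.items():
        text = text.replace(symbol, ' {} '.format(token))
    return text.lower().split()

# tokenize_text("Hello, Moe!", token_lookup())
# -> ['hello', '||comma||', 'moe', '||exclamation_mark||']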

Preprocess all the data and save it

Running the code cell below will preprocess all the data and save it to file.

In [17]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
# Preprocess Training, Validation, and Testing Data
helper.preprocess_and_save_data(data_dir, token_lookup, create_lookup_tables)

Check Point

This is your first checkpoint. If you ever decide to come back to this notebook or have to restart the notebook, you can start from here. The preprocessed data has been saved to disk.

In [1]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
import helper
import numpy as np
import problem_unittests as tests

int_text, vocab_to_int, int_to_vocab, token_dict = helper.load_preprocess()

Extra hyperparameters

In [177]:
from collections import namedtuple

hyper_params = (('embedding_size', 128),
                ('lstm_layers', 2),
                ('keep_prob', 0.7)
               )

# Named tuple for convenient dot-access to the extra hyperparameters
Hyper = namedtuple('Hyper', map(lambda x: x[0], hyper_params))
HYPER = Hyper(*list(map(lambda x: x[1], hyper_params)))

Build the Neural Network

You'll build the components necessary to build an RNN by implementing the following functions below:

  • get_inputs
  • get_init_cell
  • get_embed
  • build_rnn
  • build_nn
  • get_batches

Check the Version of TensorFlow and Access to GPU

In [3]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
from distutils.version import LooseVersion
import warnings
import tensorflow as tf

# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
TensorFlow Version: 1.0.0
Default GPU Device: /gpu:0

Input

Implement the get_inputs() function to create TF Placeholders for the Neural Network. It should create the following placeholders:

  • Input text placeholder named "input" using the TF Placeholder name parameter.
  • Targets placeholder
  • Learning Rate placeholder

Return the placeholders in the following tuple (Input, Targets, LearningRate)

In [225]:
def get_inputs():
    """
    Create TF Placeholders for input, targets, and learning rate.
    :return: Tuple (input, targets, learning rate)
    """
    
    # We use shape [None, None] to feed any batch size and any sequence length
    input_placeholder = tf.placeholder(tf.int64, [None, None], name='input')

    # Targets are [batch_size, seq_length]
    targets_placeholder = tf.placeholder(tf.int64, [None, None], name='targets')

    learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
    return input_placeholder, targets_placeholder, learning_rate_placeholder


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_get_inputs(get_inputs)
Tests Passed

Build RNN Cell and Initialize

Stack one or more BasicLSTMCells in a MultiRNNCell.

  • The RNN size should be set using rnn_size
  • Initialize the cell state using the MultiRNNCell's zero_state() function
    • Apply the name "initial_state" to the initial state using tf.identity()

Return the cell and initial state in the following tuple (Cell, InitialState)

In [227]:
def get_init_cell(batch_size, rnn_size):
    """
    Create an RNN Cell and initialize it.
    :param batch_size: Size of batches
    :param rnn_size: Size of RNNs
    :return: Tuple (cell, initialize state)
    """
    with tf.name_scope('RNN_layers'):
        lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)

        # A dropout wrapper is defined here but not used below; the commented-out
        # MultiRNNCell line shows how the wrapped cell would be stacked instead.
        drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=HYPER.keep_prob)

        #cell = tf.contrib.rnn.MultiRNNCell([drop] * HYPER.lstm_layers)
        cell = tf.contrib.rnn.MultiRNNCell([lstm] * HYPER.lstm_layers)

    _initial_state = cell.zero_state(batch_size, tf.float32)
    initial_state = tf.identity(_initial_state, name='initial_state')
    
    return cell, initial_state


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_get_init_cell(get_init_cell)
Tests Passed
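
Note: with TensorFlow 1.0 the [lstm] * HYPER.lstm_layers pattern above reuses one cell object across layers; later 1.x releases expect a separate cell per layer. A hedged sketch of that per-layer variant (with the dropout wrapper actually applied), which could replace the cell construction inside get_init_cell:

# Sketch (assumes HYPER and rnn_size are in scope, as inside get_init_cell):
# build an independent, dropout-wrapped LSTM cell for every layer.
def make_cell(rnn_size, keep_prob):
    lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)

cell = tf.contrib.rnn.MultiRNNCell(
    [make_cell(rnn_size, HYPER.keep_prob) for _ in range(HYPER.lstm_layers)])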

Word Embedding

Apply embedding to input_data using TensorFlow. Return the embedded sequence.

In [207]:
def get_embed(input_data, vocab_size, embed_dim):
    """
    Create embedding for <input_data>.
    :param input_data: TF placeholder for text input.
    :param vocab_size: Number of words in vocabulary.
    :param embed_dim: Number of embedding dimensions
    :return: Embedded input.
    """
    with tf.name_scope('Embedding'):
        embeddings = tf.Variable(
            tf.random_uniform([vocab_size, embed_dim], -1.0, 1.0)
        )

        embed = tf.nn.embedding_lookup(embeddings, input_data)
    
    return embed


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_get_embed(get_embed)
Tests Passed

Build RNN

You created an RNN Cell in the get_init_cell() function. Time to use the cell to create an RNN.

Return the outputs and final state in the following tuple (Outputs, FinalState)

In [228]:
def build_rnn(cell, inputs):
    """
    Create a RNN using a RNN Cell
    :param cell: RNN Cell
    :param inputs: Input text data
    :return: Tuple (Outputs, Final State)
    """
    ## NOTES
    # tf.nn.dynamic_rnn takes the sequence length from dim 1 of [batch_size, max_time, ...] when time_major=False (the default)
    with tf.name_scope('RNN_output'):
        outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    
    final_state = tf.identity(final_state, name='final_state')
    
    
    return outputs, final_state


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_build_rnn(build_rnn)
Tests Passed

Build the Neural Network

Apply the functions you implemented above to:

  • Apply embedding to input_data using your get_embed(input_data, vocab_size, embed_dim) function.
  • Build RNN using cell and your build_rnn(cell, inputs) function.
  • Apply a fully connected layer with a linear activation and vocab_size as the number of outputs.

Return the logits and final state in the following tuple (Logits, FinalState)

In [231]:
def build_nn(cell, rnn_size, input_data, vocab_size):
    """
    Build part of the neural network
    :param cell: RNN cell
    :param rnn_size: Size of rnns
    :param input_data: Input data
    :param vocab_size: Vocabulary size
    :return: Tuple (Logits, FinalState)
    """
    
    num_outputs = vocab_size
    
    
    ## Note: the unit test calls this function with a statically shaped input,
    ## while the real graph feeds a dynamic shape that has to be inferred at
    ## runtime, so the if statement below handles both cases.
    #
    # Some references: https://goo.gl/vD3egn
    #                  https://goo.gl/E8vT2M 
    
    if input_data.get_shape().as_list()[1] is not None:
        batch_size = input_data.get_shape().as_list()[0]
        seq_len = input_data.get_shape().as_list()[1]
    
    # Infer the dynamic tensor shape of the input
    else:
        input_dims = tf.shape(input_data)
        batch_size = input_dims[0]
        seq_len = input_dims[1]
    
    embed = get_embed(input_data, vocab_size, HYPER.embedding_size)
    
    
    ## NOTES
    # dynamic rnn automatically takes the seq size in dim=1 [batch_size, max_time, ...] see: time_major==false (default)
    
    ## Output shape
    ## [batch_size, time_step, rnn_size]
    raw_rnn_outputs, final_state = build_rnn(cell, embed)
    
    
    # Put outputs in rows
    # make the output into [batch_size*time_step, rnn_size] for easy matmul
    with tf.name_scope('sequence_reshape'):
        outputs = tf.reshape(raw_rnn_outputs, [-1, rnn_size], name='rnn_output')
    
    
    # Why a linear activation and not softmax? Because seq2seq.sequence_loss
    # computes the loss efficiently straight from the raw logits.
    with tf.name_scope('logits'):
        
        linear_w = tf.Variable(tf.truncated_normal((rnn_size, num_outputs), stddev=0.05), name='linear_w')
        linear_b = tf.Variable(tf.zeros(num_outputs), name='linear_b')

        logits = tf.matmul(outputs, linear_w) + linear_b
    
    
    
    # Reshape the logits back into the original input shape -> [batch_size, seq_len, num_classes]
    # We do this because the loss function seq2seq.sequence_loss expects logits of shape [batch_size, seq_len, num_decoded_symbols]
    with tf.name_scope('logits_reshape_to_loss'):
        logits = tf.reshape(logits, [batch_size, seq_len, num_outputs], name='logits')
        print('logits after reshape: ', logits)
    
    return logits, final_state


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_build_nn(build_nn)
logits after reshape:  Tensor("logits_reshape_to_loss/logits:0", shape=(128, 5, 27), dtype=float32)
Tests Passed
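
As an aside, the manual weight matrix and matmul above could likely be replaced with a linear fully connected layer inside build_nn; a sketch assuming tf.contrib.layers is available (not the implementation used in this notebook):

# Sketch: outputs has shape [batch_size * seq_len, rnn_size]; a fully connected
# layer with no activation produces logits of shape [batch_size * seq_len, vocab_size]
# before the final reshape back to [batch_size, seq_len, vocab_size].
logits = tf.contrib.layers.fully_connected(outputs, vocab_size, activation_fn=None)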

Batches

Implement get_batches to create batches of input and targets using int_text. The batches should be a Numpy array with the shape (number of batches, 2, batch size, sequence length). Each batch contains two elements:

  • The first element is a single batch of input with the shape [batch size, sequence length]
  • The second element is a single batch of targets with the shape [batch size, sequence length]

If you can't fill the last batch with enough data, drop the last batch.

For example, get_batches([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 2, 3) would return a Numpy array of the following:

[
  # First Batch
  [
    # Batch of Input
    [[ 1  2  3], [ 7  8  9]],
    # Batch of targets
    [[ 2  3  4], [ 8  9 10]]
  ],

  # Second Batch
  [
    # Batch of Input
    [[ 4  5  6], [10 11 12]],
    # Batch of targets
    [[ 5  6  7], [11 12 13]]
  ]
]
In [233]:
def get_batches(int_text, batch_size, seq_length):
    """
    Return batches of input and target
    :param int_text: Text with the words replaced by their ids
    :param batch_size: The size of batch
    :param seq_length: The length of sequence
    :return: Batches as a Numpy array
    """
    
    slice_size = batch_size * seq_length
    n_batches = int(len(int_text)/slice_size)
    
    # input part
    _inputs = np.array(int_text[:n_batches*slice_size])
    
    # target part
    _targets = np.array(int_text[1:n_batches*slice_size + 1])
    

    # Split the inputs and targets into n_batches slices of batch_size*seq_length items each
    # [batch, batch, ...]
    inputs, targets = np.split(_inputs, n_batches), np.split(_targets, n_batches)
    
    # concat inputs and targets
    batches = np.c_[inputs, targets]
    #print(batches.shape)
    
    # Reshape into final batches output
    batches = batches.reshape((-1, 2, batch_size, seq_length))

    #print(batches[0][1])

    
    return batches


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_get_batches(get_batches)
Tests Passed
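
A quick way to eyeball the output against the worked example above. Note that this implementation may order the sequences within each batch differently from that example while still satisfying the shape and input/target offset checked by the unit test:

# Quick check (not part of the project tests): shape should be
# (number of batches, 2, batch size, sequence length) = (2, 2, 2, 3) here.
example_batches = get_batches(list(range(1, 16)), 2, 3)
print(example_batches.shape)
print(example_batches[0])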

Neural Network Training

Hyperparameters

Tune the following parameters:

  • Set num_epochs to the number of epochs.
  • Set batch_size to the batch size.
  • Set rnn_size to the size of the RNNs.
  • Set seq_length to the length of sequence.
  • Set learning_rate to the learning rate.
  • Set show_every_n_batches to how often, in batches, the neural network should print training progress.
In [234]:
# Number of Epochs
num_epochs = 1000
# Batch Size
batch_size = 128
# RNN Size
rnn_size = 70
# Sequence Length
seq_length = 100
# Learning Rate
learning_rate = 1e-3
# Show stats for every n number of batches
show_every_n_batches = 10

"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
save_dir = './save'

Build the Graph

Build the graph using the neural network you implemented.

In [235]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
from tensorflow.contrib import seq2seq

train_graph = tf.Graph()
with train_graph.as_default():
    vocab_size = len(int_to_vocab)
    input_text, targets, lr = get_inputs()
    input_data_shape = tf.shape(input_text)
    cell, initial_state = get_init_cell(input_data_shape[0], rnn_size)
    logits, final_state = build_nn(cell, rnn_size, input_text, vocab_size)

    # Probabilities for generating words
    probs = tf.nn.softmax(logits, name='probs')

    # Loss function
    cost = seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([input_data_shape[0], input_data_shape[1]]))

    # Optimizer
    optimizer = tf.train.AdamOptimizer(lr)

    # Gradient Clipping
    gradients = optimizer.compute_gradients(cost)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients]
    train_op = optimizer.apply_gradients(capped_gradients)
logits after reshape:  Tensor("logits_reshape_to_loss/logits:0", shape=(?, ?, 6779), dtype=float32)
In [238]:
# write out the graph for tensorboard

with tf.Session(graph=train_graph) as sess:
    file_writer = tf.summary.FileWriter('./logs/1', sess.graph)
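
With the graph written out, it can be inspected by launching TensorBoard from a terminal, for example:

tensorboard --logdir ./logs/1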

Train

Train the neural network on the preprocessed data. If you have a hard time getting a good loss, check the forums to see if anyone is having the same problem.

In [197]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
batches = get_batches(int_text, batch_size, seq_length)

with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_i in range(num_epochs):
        state = sess.run(initial_state, {input_text: batches[0][0]})

        for batch_i, (x, y) in enumerate(batches):
            feed = {
                input_text: x,
                targets: y,
                initial_state: state,
                lr: learning_rate}
            train_loss, state, _ = sess.run([cost, final_state, train_op], feed)

            # Show every <show_every_n_batches> batches
            if (epoch_i * len(batches) + batch_i) % show_every_n_batches == 0:
                print('Epoch {:>3} Batch {:>4}/{}   train_loss = {:.3f}'.format(
                    epoch_i,
                    batch_i,
                    len(batches),
                    train_loss))

    # Save Model
    saver = tf.train.Saver()
    saver.save(sess, save_dir)
    print('Model Trained and Saved')
Epoch   0 Batch    0/5   train_loss = 8.825
Epoch   2 Batch    0/5   train_loss = 6.441
Epoch   4 Batch    0/5   train_loss = 6.023
Epoch   6 Batch    0/5   train_loss = 5.927
Epoch   8 Batch    0/5   train_loss = 5.903
Epoch  10 Batch    0/5   train_loss = 5.883
Epoch  12 Batch    0/5   train_loss = 5.874
Epoch  14 Batch    0/5   train_loss = 5.858
Epoch  16 Batch    0/5   train_loss = 5.833
Epoch  18 Batch    0/5   train_loss = 5.794
Epoch  20 Batch    0/5   train_loss = 5.739
Epoch  22 Batch    0/5   train_loss = 5.682
Epoch  24 Batch    0/5   train_loss = 5.626
Epoch  26 Batch    0/5   train_loss = 5.572
Epoch  28 Batch    0/5   train_loss = 5.521
Epoch  30 Batch    0/5   train_loss = 5.471
Epoch  32 Batch    0/5   train_loss = 5.421
Epoch  34 Batch    0/5   train_loss = 5.365
Epoch  36 Batch    0/5   train_loss = 5.304
Epoch  38 Batch    0/5   train_loss = 5.244
Epoch  40 Batch    0/5   train_loss = 5.185
Epoch  42 Batch    0/5   train_loss = 5.124
Epoch  44 Batch    0/5   train_loss = 5.063
Epoch  46 Batch    0/5   train_loss = 5.003
Epoch  48 Batch    0/5   train_loss = 4.945
Epoch  50 Batch    0/5   train_loss = 4.891
Epoch  52 Batch    0/5   train_loss = 4.841
Epoch  54 Batch    0/5   train_loss = 4.794
Epoch  56 Batch    0/5   train_loss = 4.751
Epoch  58 Batch    0/5   train_loss = 4.710
Epoch  60 Batch    0/5   train_loss = 4.669
Epoch  62 Batch    0/5   train_loss = 4.638
Epoch  64 Batch    0/5   train_loss = 4.638
Epoch  66 Batch    0/5   train_loss = 4.589
Epoch  68 Batch    0/5   train_loss = 4.537
Epoch  70 Batch    0/5   train_loss = 4.501
Epoch  72 Batch    0/5   train_loss = 4.469
Epoch  74 Batch    0/5   train_loss = 4.436
Epoch  76 Batch    0/5   train_loss = 4.405
Epoch  78 Batch    0/5   train_loss = 4.375
Epoch  80 Batch    0/5   train_loss = 4.344
Epoch  82 Batch    0/5   train_loss = 4.363
Epoch  84 Batch    0/5   train_loss = 4.311
Epoch  86 Batch    0/5   train_loss = 4.274
Epoch  88 Batch    0/5   train_loss = 4.240
Epoch  90 Batch    0/5   train_loss = 4.211
Epoch  92 Batch    0/5   train_loss = 4.182
Epoch  94 Batch    0/5   train_loss = 4.155
Epoch  96 Batch    0/5   train_loss = 4.135
Epoch  98 Batch    0/5   train_loss = 4.107
Epoch 100 Batch    0/5   train_loss = 4.093
Epoch 102 Batch    0/5   train_loss = 4.053
Epoch 104 Batch    0/5   train_loss = 4.030
Epoch 106 Batch    0/5   train_loss = 4.002
Epoch 108 Batch    0/5   train_loss = 3.978
Epoch 110 Batch    0/5   train_loss = 3.951
Epoch 112 Batch    0/5   train_loss = 3.928
Epoch 114 Batch    0/5   train_loss = 3.902
Epoch 116 Batch    0/5   train_loss = 3.884
Epoch 118 Batch    0/5   train_loss = 3.862
Epoch 120 Batch    0/5   train_loss = 3.840
Epoch 122 Batch    0/5   train_loss = 3.814
Epoch 124 Batch    0/5   train_loss = 3.803
Epoch 126 Batch    0/5   train_loss = 3.775
Epoch 128 Batch    0/5   train_loss = 3.738
Epoch 130 Batch    0/5   train_loss = 3.714
Epoch 132 Batch    0/5   train_loss = 3.690
Epoch 134 Batch    0/5   train_loss = 3.665
Epoch 136 Batch    0/5   train_loss = 3.642
Epoch 138 Batch    0/5   train_loss = 3.619
Epoch 140 Batch    0/5   train_loss = 3.596
Epoch 142 Batch    0/5   train_loss = 3.577
Epoch 144 Batch    0/5   train_loss = 3.588
Epoch 146 Batch    0/5   train_loss = 3.561
Epoch 148 Batch    0/5   train_loss = 3.537
Epoch 150 Batch    0/5   train_loss = 3.494
Epoch 152 Batch    0/5   train_loss = 3.475
Epoch 154 Batch    0/5   train_loss = 3.444
Epoch 156 Batch    0/5   train_loss = 3.431
Epoch 158 Batch    0/5   train_loss = 3.403
Epoch 160 Batch    0/5   train_loss = 3.393
Epoch 162 Batch    0/5   train_loss = 3.371
Epoch 164 Batch    0/5   train_loss = 3.352
Epoch 166 Batch    0/5   train_loss = 3.323
Epoch 168 Batch    0/5   train_loss = 3.328
Epoch 170 Batch    0/5   train_loss = 3.281
Epoch 172 Batch    0/5   train_loss = 3.261
Epoch 174 Batch    0/5   train_loss = 3.238
Epoch 176 Batch    0/5   train_loss = 3.216
Epoch 178 Batch    0/5   train_loss = 3.197
Epoch 180 Batch    0/5   train_loss = 3.172
Epoch 182 Batch    0/5   train_loss = 3.169
Epoch 184 Batch    0/5   train_loss = 3.140
Epoch 186 Batch    0/5   train_loss = 3.136
Epoch 188 Batch    0/5   train_loss = 3.145
Epoch 190 Batch    0/5   train_loss = 3.106
Epoch 192 Batch    0/5   train_loss = 3.069
Epoch 194 Batch    0/5   train_loss = 3.038
Epoch 196 Batch    0/5   train_loss = 3.019
Epoch 198 Batch    0/5   train_loss = 2.995
Epoch 200 Batch    0/5   train_loss = 2.979
Epoch 202 Batch    0/5   train_loss = 2.960
Epoch 204 Batch    0/5   train_loss = 2.943
Epoch 206 Batch    0/5   train_loss = 2.963
Epoch 208 Batch    0/5   train_loss = 2.917
Epoch 210 Batch    0/5   train_loss = 2.898
Epoch 212 Batch    0/5   train_loss = 2.867
Epoch 214 Batch    0/5   train_loss = 2.863
Epoch 216 Batch    0/5   train_loss = 2.834
Epoch 218 Batch    0/5   train_loss = 2.809
Epoch 220 Batch    0/5   train_loss = 2.797
Epoch 222 Batch    0/5   train_loss = 2.774
Epoch 224 Batch    0/5   train_loss = 2.759
Epoch 226 Batch    0/5   train_loss = 2.732
Epoch 228 Batch    0/5   train_loss = 2.742
Epoch 230 Batch    0/5   train_loss = 2.704
Epoch 232 Batch    0/5   train_loss = 2.703
Epoch 234 Batch    0/5   train_loss = 2.663
Epoch 236 Batch    0/5   train_loss = 2.672
Epoch 238 Batch    0/5   train_loss = 2.638
Epoch 240 Batch    0/5   train_loss = 2.620
Epoch 242 Batch    0/5   train_loss = 2.595
Epoch 244 Batch    0/5   train_loss = 2.585
Epoch 246 Batch    0/5   train_loss = 2.563
Epoch 248 Batch    0/5   train_loss = 2.539
Epoch 250 Batch    0/5   train_loss = 2.534
Epoch 252 Batch    0/5   train_loss = 2.517
Epoch 254 Batch    0/5   train_loss = 2.497
Epoch 256 Batch    0/5   train_loss = 2.475
Epoch 258 Batch    0/5   train_loss = 2.463
Epoch 260 Batch    0/5   train_loss = 2.478
Epoch 262 Batch    0/5   train_loss = 2.450
Epoch 264 Batch    0/5   train_loss = 2.436
Epoch 266 Batch    0/5   train_loss = 2.417
Epoch 268 Batch    0/5   train_loss = 2.384
Epoch 270 Batch    0/5   train_loss = 2.363
Epoch 272 Batch    0/5   train_loss = 2.340
Epoch 274 Batch    0/5   train_loss = 2.323
Epoch 276 Batch    0/5   train_loss = 2.314
Epoch 278 Batch    0/5   train_loss = 2.302
Epoch 280 Batch    0/5   train_loss = 2.300
Epoch 282 Batch    0/5   train_loss = 2.300
Epoch 284 Batch    0/5   train_loss = 2.283
Epoch 286 Batch    0/5   train_loss = 2.246
Epoch 288 Batch    0/5   train_loss = 2.246
Epoch 290 Batch    0/5   train_loss = 2.210
Epoch 292 Batch    0/5   train_loss = 2.203
Epoch 294 Batch    0/5   train_loss = 2.185
Epoch 296 Batch    0/5   train_loss = 2.170
Epoch 298 Batch    0/5   train_loss = 2.150
Epoch 300 Batch    0/5   train_loss = 2.130
Epoch 302 Batch    0/5   train_loss = 2.132
Epoch 304 Batch    0/5   train_loss = 2.113
Epoch 306 Batch    0/5   train_loss = 2.083
Epoch 308 Batch    0/5   train_loss = 2.073
Epoch 310 Batch    0/5   train_loss = 2.060
Epoch 312 Batch    0/5   train_loss = 2.072
Epoch 314 Batch    0/5   train_loss = 2.081
Epoch 316 Batch    0/5   train_loss = 2.031
Epoch 318 Batch    0/5   train_loss = 2.007
Epoch 320 Batch    0/5   train_loss = 2.001
Epoch 322 Batch    0/5   train_loss = 1.987
Epoch 324 Batch    0/5   train_loss = 1.978
Epoch 326 Batch    0/5   train_loss = 1.963
Epoch 328 Batch    0/5   train_loss = 1.952
Epoch 330 Batch    0/5   train_loss = 1.932
Epoch 332 Batch    0/5   train_loss = 1.918
Epoch 334 Batch    0/5   train_loss = 1.898
Epoch 336 Batch    0/5   train_loss = 1.885
Epoch 338 Batch    0/5   train_loss = 1.872
Epoch 340 Batch    0/5   train_loss = 1.864
Epoch 342 Batch    0/5   train_loss = 1.867
Epoch 344 Batch    0/5   train_loss = 1.848
Epoch 346 Batch    0/5   train_loss = 1.821
Epoch 348 Batch    0/5   train_loss = 1.814
Epoch 350 Batch    0/5   train_loss = 1.788
Epoch 352 Batch    0/5   train_loss = 1.806
Epoch 354 Batch    0/5   train_loss = 1.790
Epoch 356 Batch    0/5   train_loss = 1.761
Epoch 358 Batch    0/5   train_loss = 1.745
Epoch 360 Batch    0/5   train_loss = 1.735
Epoch 362 Batch    0/5   train_loss = 1.718
Epoch 364 Batch    0/5   train_loss = 1.747
Epoch 366 Batch    0/5   train_loss = 1.726
Epoch 368 Batch    0/5   train_loss = 1.753
Epoch 370 Batch    0/5   train_loss = 1.703
Epoch 372 Batch    0/5   train_loss = 1.662
Epoch 374 Batch    0/5   train_loss = 1.643
Epoch 376 Batch    0/5   train_loss = 1.624
Epoch 378 Batch    0/5   train_loss = 1.617
Epoch 380 Batch    0/5   train_loss = 1.598
Epoch 382 Batch    0/5   train_loss = 1.613
Epoch 384 Batch    0/5   train_loss = 1.601
Epoch 386 Batch    0/5   train_loss = 1.584
Epoch 388 Batch    0/5   train_loss = 1.569
Epoch 390 Batch    0/5   train_loss = 1.557
Epoch 392 Batch    0/5   train_loss = 1.534
Epoch 394 Batch    0/5   train_loss = 1.534
Epoch 396 Batch    0/5   train_loss = 1.520
Epoch 398 Batch    0/5   train_loss = 1.547
Epoch 400 Batch    0/5   train_loss = 1.545
Epoch 402 Batch    0/5   train_loss = 1.521
Epoch 404 Batch    0/5   train_loss = 1.486
Epoch 406 Batch    0/5   train_loss = 1.469
Epoch 408 Batch    0/5   train_loss = 1.458
Epoch 410 Batch    0/5   train_loss = 1.442
Epoch 412 Batch    0/5   train_loss = 1.431
Epoch 414 Batch    0/5   train_loss = 1.410
Epoch 416 Batch    0/5   train_loss = 1.411
Epoch 418 Batch    0/5   train_loss = 1.412
Epoch 420 Batch    0/5   train_loss = 1.398
Epoch 422 Batch    0/5   train_loss = 1.417
Epoch 424 Batch    0/5   train_loss = 1.381
Epoch 426 Batch    0/5   train_loss = 1.355
Epoch 428 Batch    0/5   train_loss = 1.354
Epoch 430 Batch    0/5   train_loss = 1.338
Epoch 432 Batch    0/5   train_loss = 1.321
Epoch 434 Batch    0/5   train_loss = 1.326
Epoch 436 Batch    0/5   train_loss = 1.324
Epoch 438 Batch    0/5   train_loss = 1.314
Epoch 440 Batch    0/5   train_loss = 1.292
Epoch 442 Batch    0/5   train_loss = 1.279
Epoch 444 Batch    0/5   train_loss = 1.259
Epoch 446 Batch    0/5   train_loss = 1.283
Epoch 448 Batch    0/5   train_loss = 1.274
Epoch 450 Batch    0/5   train_loss = 1.251
Epoch 452 Batch    0/5   train_loss = 1.279
Epoch 454 Batch    0/5   train_loss = 1.249
Epoch 456 Batch    0/5   train_loss = 1.214
Epoch 458 Batch    0/5   train_loss = 1.196
Epoch 460 Batch    0/5   train_loss = 1.185
Epoch 462 Batch    0/5   train_loss = 1.174
Epoch 464 Batch    0/5   train_loss = 1.158
Epoch 466 Batch    0/5   train_loss = 1.195
Epoch 468 Batch    0/5   train_loss = 1.158
Epoch 470 Batch    0/5   train_loss = 1.145
Epoch 472 Batch    0/5   train_loss = 1.160
Epoch 474 Batch    0/5   train_loss = 1.123
Epoch 476 Batch    0/5   train_loss = 1.118
Epoch 478 Batch    0/5   train_loss = 1.103
Epoch 480 Batch    0/5   train_loss = 1.088
Epoch 482 Batch    0/5   train_loss = 1.089
Epoch 484 Batch    0/5   train_loss = 1.094
Epoch 486 Batch    0/5   train_loss = 1.092
Epoch 488 Batch    0/5   train_loss = 1.106
Epoch 490 Batch    0/5   train_loss = 1.053
Epoch 492 Batch    0/5   train_loss = 1.052
Epoch 494 Batch    0/5   train_loss = 1.046
Epoch 496 Batch    0/5   train_loss = 1.030
Epoch 498 Batch    0/5   train_loss = 1.021
Epoch 500 Batch    0/5   train_loss = 1.020
Epoch 502 Batch    0/5   train_loss = 1.046
Epoch 504 Batch    0/5   train_loss = 1.040
Epoch 506 Batch    0/5   train_loss = 1.026
Epoch 508 Batch    0/5   train_loss = 0.982
Epoch 510 Batch    0/5   train_loss = 0.969
Epoch 512 Batch    0/5   train_loss = 0.962
Epoch 514 Batch    0/5   train_loss = 0.946
Epoch 516 Batch    0/5   train_loss = 0.941
Epoch 518 Batch    0/5   train_loss = 0.951
Epoch 520 Batch    0/5   train_loss = 0.945
Epoch 522 Batch    0/5   train_loss = 0.952
Epoch 524 Batch    0/5   train_loss = 0.931
Epoch 526 Batch    0/5   train_loss = 0.905
Epoch 528 Batch    0/5   train_loss = 0.893
Epoch 530 Batch    0/5   train_loss = 0.881
Epoch 532 Batch    0/5   train_loss = 0.882
Epoch 534 Batch    0/5   train_loss = 0.871
Epoch 536 Batch    0/5   train_loss = 0.904
Epoch 538 Batch    0/5   train_loss = 0.893
Epoch 540 Batch    0/5   train_loss = 0.884
Epoch 542 Batch    0/5   train_loss = 0.864
Epoch 544 Batch    0/5   train_loss = 0.854
Epoch 546 Batch    0/5   train_loss = 0.854
Epoch 548 Batch    0/5   train_loss = 0.836
Epoch 550 Batch    0/5   train_loss = 0.816
Epoch 552 Batch    0/5   train_loss = 0.829
Epoch 554 Batch    0/5   train_loss = 0.813
Epoch 556 Batch    0/5   train_loss = 0.798
Epoch 558 Batch    0/5   train_loss = 0.808
Epoch 560 Batch    0/5   train_loss = 0.789
Epoch 562 Batch    0/5   train_loss = 0.791
Epoch 564 Batch    0/5   train_loss = 0.779
Epoch 566 Batch    0/5   train_loss = 0.765
Epoch 568 Batch    0/5   train_loss = 0.746
Epoch 570 Batch    0/5   train_loss = 0.746
Epoch 572 Batch    0/5   train_loss = 0.733
Epoch 574 Batch    0/5   train_loss = 0.733
Epoch 576 Batch    0/5   train_loss = 0.752
Epoch 578 Batch    0/5   train_loss = 0.727
Epoch 580 Batch    0/5   train_loss = 0.712
Epoch 582 Batch    0/5   train_loss = 0.711
Epoch 584 Batch    0/5   train_loss = 0.708
Epoch 586 Batch    0/5   train_loss = 0.695
Epoch 588 Batch    0/5   train_loss = 0.699
Epoch 590 Batch    0/5   train_loss = 0.688
Epoch 592 Batch    0/5   train_loss = 0.682
Epoch 594 Batch    0/5   train_loss = 0.703
Epoch 596 Batch    0/5   train_loss = 0.681
Epoch 598 Batch    0/5   train_loss = 0.672
Epoch 600 Batch    0/5   train_loss = 0.678
Epoch 602 Batch    0/5   train_loss = 0.657
Epoch 604 Batch    0/5   train_loss = 0.652
Epoch 606 Batch    0/5   train_loss = 0.627
Epoch 608 Batch    0/5   train_loss = 0.623
Epoch 610 Batch    0/5   train_loss = 0.633
Epoch 612 Batch    0/5   train_loss = 0.608
Epoch 614 Batch    0/5   train_loss = 0.614
Epoch 616 Batch    0/5   train_loss = 0.620
Epoch 618 Batch    0/5   train_loss = 0.610
Epoch 620 Batch    0/5   train_loss = 0.596
Epoch 622 Batch    0/5   train_loss = 0.596
Epoch 624 Batch    0/5   train_loss = 0.605
Epoch 626 Batch    0/5   train_loss = 0.574
Epoch 628 Batch    0/5   train_loss = 0.581
Epoch 630 Batch    0/5   train_loss = 0.571
Epoch 632 Batch    0/5   train_loss = 0.563
Epoch 634 Batch    0/5   train_loss = 0.582
Epoch 636 Batch    0/5   train_loss = 0.579
Epoch 638 Batch    0/5   train_loss = 0.562
Epoch 640 Batch    0/5   train_loss = 0.549
Epoch 642 Batch    0/5   train_loss = 0.540
Epoch 644 Batch    0/5   train_loss = 0.520
Epoch 646 Batch    0/5   train_loss = 0.515
Epoch 648 Batch    0/5   train_loss = 0.509
Epoch 650 Batch    0/5   train_loss = 0.509
Epoch 652 Batch    0/5   train_loss = 0.527
Epoch 654 Batch    0/5   train_loss = 0.524
Epoch 656 Batch    0/5   train_loss = 0.509
Epoch 658 Batch    0/5   train_loss = 0.523
Epoch 660 Batch    0/5   train_loss = 0.502
Epoch 662 Batch    0/5   train_loss = 0.477
Epoch 664 Batch    0/5   train_loss = 0.473
Epoch 666 Batch    0/5   train_loss = 0.463
Epoch 668 Batch    0/5   train_loss = 0.457
Epoch 670 Batch    0/5   train_loss = 0.455
Epoch 672 Batch    0/5   train_loss = 0.459
Epoch 674 Batch    0/5   train_loss = 0.475
Epoch 676 Batch    0/5   train_loss = 0.471
Epoch 678 Batch    0/5   train_loss = 0.455
Epoch 680 Batch    0/5   train_loss = 0.443
Epoch 682 Batch    0/5   train_loss = 0.456
Epoch 684 Batch    0/5   train_loss = 0.440
Epoch 686 Batch    0/5   train_loss = 0.421
Epoch 688 Batch    0/5   train_loss = 0.413
Epoch 690 Batch    0/5   train_loss = 0.405
Epoch 692 Batch    0/5   train_loss = 0.401
Epoch 694 Batch    0/5   train_loss = 0.404
Epoch 696 Batch    0/5   train_loss = 0.400
Epoch 698 Batch    0/5   train_loss = 0.428
Epoch 700 Batch    0/5   train_loss = 0.451
Epoch 702 Batch    0/5   train_loss = 0.426
Epoch 704 Batch    0/5   train_loss = 0.410
Epoch 706 Batch    0/5   train_loss = 0.422
Epoch 708 Batch    0/5   train_loss = 0.398
Epoch 710 Batch    0/5   train_loss = 0.377
Epoch 712 Batch    0/5   train_loss = 0.368
Epoch 714 Batch    0/5   train_loss = 0.358
Epoch 716 Batch    0/5   train_loss = 0.352
Epoch 718 Batch    0/5   train_loss = 0.349
Epoch 720 Batch    0/5   train_loss = 0.344
Epoch 722 Batch    0/5   train_loss = 0.346
Epoch 724 Batch    0/5   train_loss = 0.345
Epoch 726 Batch    0/5   train_loss = 0.337
Epoch 728 Batch    0/5   train_loss = 0.345
Epoch 730 Batch    0/5   train_loss = 0.348
Epoch 732 Batch    0/5   train_loss = 0.358
Epoch 734 Batch    0/5   train_loss = 0.346
Epoch 736 Batch    0/5   train_loss = 0.337
Epoch 738 Batch    0/5   train_loss = 0.329
Epoch 740 Batch    0/5   train_loss = 0.320
Epoch 742 Batch    0/5   train_loss = 0.323
Epoch 744 Batch    0/5   train_loss = 0.316
Epoch 746 Batch    0/5   train_loss = 0.304
Epoch 748 Batch    0/5   train_loss = 0.299
Epoch 750 Batch    0/5   train_loss = 0.292
Epoch 752 Batch    0/5   train_loss = 0.288
Epoch 754 Batch    0/5   train_loss = 0.289
Epoch 756 Batch    0/5   train_loss = 0.284
Epoch 758 Batch    0/5   train_loss = 0.290
Epoch 760 Batch    0/5   train_loss = 0.304
Epoch 762 Batch    0/5   train_loss = 0.311
Epoch 764 Batch    0/5   train_loss = 0.405
Epoch 766 Batch    0/5   train_loss = 0.390
Epoch 768 Batch    0/5   train_loss = 0.344
Epoch 770 Batch    0/5   train_loss = 0.320
Epoch 772 Batch    0/5   train_loss = 0.280
Epoch 774 Batch    0/5   train_loss = 0.265
Epoch 776 Batch    0/5   train_loss = 0.258
Epoch 778 Batch    0/5   train_loss = 0.252
Epoch 780 Batch    0/5   train_loss = 0.247
Epoch 782 Batch    0/5   train_loss = 0.243
Epoch 784 Batch    0/5   train_loss = 0.240
Epoch 786 Batch    0/5   train_loss = 0.237
Epoch 788 Batch    0/5   train_loss = 0.233
Epoch 790 Batch    0/5   train_loss = 0.231
Epoch 792 Batch    0/5   train_loss = 0.229
Epoch 794 Batch    0/5   train_loss = 0.225
Epoch 796 Batch    0/5   train_loss = 0.230
Epoch 798 Batch    0/5   train_loss = 0.226
Epoch 800 Batch    0/5   train_loss = 0.222
Epoch 802 Batch    0/5   train_loss = 0.237
Epoch 804 Batch    0/5   train_loss = 0.225
Epoch 806 Batch    0/5   train_loss = 0.225
Epoch 808 Batch    0/5   train_loss = 0.245
Epoch 810 Batch    0/5   train_loss = 0.227
Epoch 812 Batch    0/5   train_loss = 0.210
Epoch 814 Batch    0/5   train_loss = 0.206
Epoch 816 Batch    0/5   train_loss = 0.202
Epoch 818 Batch    0/5   train_loss = 0.198
Epoch 820 Batch    0/5   train_loss = 0.195
Epoch 822 Batch    0/5   train_loss = 0.192
Epoch 824 Batch    0/5   train_loss = 0.189
Epoch 826 Batch    0/5   train_loss = 0.189
Epoch 828 Batch    0/5   train_loss = 0.187
Epoch 830 Batch    0/5   train_loss = 0.186
Epoch 832 Batch    0/5   train_loss = 0.187
Epoch 834 Batch    0/5   train_loss = 0.189
Epoch 836 Batch    0/5   train_loss = 0.189
Epoch 838 Batch    0/5   train_loss = 0.197
Epoch 840 Batch    0/5   train_loss = 0.207
Epoch 842 Batch    0/5   train_loss = 0.196
Epoch 844 Batch    0/5   train_loss = 0.187
Epoch 846 Batch    0/5   train_loss = 0.197
Epoch 848 Batch    0/5   train_loss = 0.189
Epoch 850 Batch    0/5   train_loss = 0.176
Epoch 852 Batch    0/5   train_loss = 0.171
Epoch 854 Batch    0/5   train_loss = 0.164
Epoch 856 Batch    0/5   train_loss = 0.161
Epoch 858 Batch    0/5   train_loss = 0.157
Epoch 860 Batch    0/5   train_loss = 0.154
Epoch 862 Batch    0/5   train_loss = 0.152
Epoch 864 Batch    0/5   train_loss = 0.150
Epoch 866 Batch    0/5   train_loss = 0.148
Epoch 868 Batch    0/5   train_loss = 0.146
Epoch 870 Batch    0/5   train_loss = 0.145
Epoch 872 Batch    0/5   train_loss = 0.145
Epoch 874 Batch    0/5   train_loss = 0.142
Epoch 876 Batch    0/5   train_loss = 0.143
Epoch 878 Batch    0/5   train_loss = 0.159
Epoch 880 Batch    0/5   train_loss = 0.145
Epoch 882 Batch    0/5   train_loss = 0.161
Epoch 884 Batch    0/5   train_loss = 0.211
Epoch 886 Batch    0/5   train_loss = 0.196
Epoch 888 Batch    0/5   train_loss = 0.335
Epoch 890 Batch    0/5   train_loss = 0.325
Epoch 892 Batch    0/5   train_loss = 0.279
Epoch 894 Batch    0/5   train_loss = 0.244
Epoch 896 Batch    0/5   train_loss = 0.214
Epoch 898 Batch    0/5   train_loss = 0.174
Epoch 900 Batch    0/5   train_loss = 0.147
Epoch 902 Batch    0/5   train_loss = 0.138
Epoch 904 Batch    0/5   train_loss = 0.131
Epoch 906 Batch    0/5   train_loss = 0.128
Epoch 908 Batch    0/5   train_loss = 0.125
Epoch 910 Batch    0/5   train_loss = 0.123
Epoch 912 Batch    0/5   train_loss = 0.121
Epoch 914 Batch    0/5   train_loss = 0.119
Epoch 916 Batch    0/5   train_loss = 0.117
Epoch 918 Batch    0/5   train_loss = 0.116
Epoch 920 Batch    0/5   train_loss = 0.114
Epoch 922 Batch    0/5   train_loss = 0.113
Epoch 924 Batch    0/5   train_loss = 0.112
Epoch 926 Batch    0/5   train_loss = 0.111
Epoch 928 Batch    0/5   train_loss = 0.109
Epoch 930 Batch    0/5   train_loss = 0.108
Epoch 932 Batch    0/5   train_loss = 0.107
Epoch 934 Batch    0/5   train_loss = 0.106
Epoch 936 Batch    0/5   train_loss = 0.105
Epoch 938 Batch    0/5   train_loss = 0.103
Epoch 940 Batch    0/5   train_loss = 0.102
Epoch 942 Batch    0/5   train_loss = 0.101
Epoch 944 Batch    0/5   train_loss = 0.100
Epoch 946 Batch    0/5   train_loss = 0.099
Epoch 948 Batch    0/5   train_loss = 0.098
Epoch 950 Batch    0/5   train_loss = 0.097
Epoch 952 Batch    0/5   train_loss = 0.096
Epoch 954 Batch    0/5   train_loss = 0.095
Epoch 956 Batch    0/5   train_loss = 0.094
Epoch 958 Batch    0/5   train_loss = 0.094
Epoch 960 Batch    0/5   train_loss = 0.093
Epoch 962 Batch    0/5   train_loss = 0.092
Epoch 964 Batch    0/5   train_loss = 0.091
Epoch 966 Batch    0/5   train_loss = 0.090
Epoch 968 Batch    0/5   train_loss = 0.089
Epoch 970 Batch    0/5   train_loss = 0.088
Epoch 972 Batch    0/5   train_loss = 0.088
Epoch 974 Batch    0/5   train_loss = 0.087
Epoch 976 Batch    0/5   train_loss = 0.086
Epoch 978 Batch    0/5   train_loss = 0.085
Epoch 980 Batch    0/5   train_loss = 0.084
Epoch 982 Batch    0/5   train_loss = 0.083
Epoch 984 Batch    0/5   train_loss = 0.083
Epoch 986 Batch    0/5   train_loss = 0.082
Epoch 988 Batch    0/5   train_loss = 0.081
Epoch 990 Batch    0/5   train_loss = 0.080
Epoch 992 Batch    0/5   train_loss = 0.080
Epoch 994 Batch    0/5   train_loss = 0.079
Epoch 996 Batch    0/5   train_loss = 0.078
Epoch 998 Batch    0/5   train_loss = 0.078
Model Trained and Saved

Save Parameters

Save seq_length and save_dir for generating a new TV script.

In [198]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
# Save parameters for checkpoint
helper.save_params((seq_length, save_dir))

Checkpoint

In [272]:
"""
DON'T MODIFY ANYTHING IN THIS CELL
"""
import tensorflow as tf
import numpy as np
import helper
import problem_unittests as tests

_, vocab_to_int, int_to_vocab, token_dict = helper.load_preprocess()
seq_length, load_dir = helper.load_params()

Implement Generate Functions

Get Tensors

Get tensors from loaded_graph using the function get_tensor_by_name(). Get the tensors using the following names:

  • "input:0"
  • "initial_state:0"
  • "final_state:0"
  • "probs:0"

Return the tensors in the following tuple (InputTensor, InitialStateTensor, FinalStateTensor, ProbsTensor)

In [273]:
def get_tensors(loaded_graph):
    """
    Get input, initial state, final state, and probabilities tensor from <loaded_graph>
    :param loaded_graph: TensorFlow graph loaded from file
    :return: Tuple (InputTensor, InitialStateTensor, FinalStateTensor, ProbsTensor)
    """
    
    t_input = loaded_graph.get_tensor_by_name('input:0')
    t_initial_state = loaded_graph.get_tensor_by_name('initial_state:0')
    t_final_state = loaded_graph.get_tensor_by_name('final_state:0')
    t_probs = loaded_graph.get_tensor_by_name('probs:0')
    return t_input, t_initial_state, t_final_state, t_probs


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_get_tensors(get_tensors)
Tests Passed

Choose Word

Implement the pick_word() function to select the next word using probabilities.

In [274]:
def pick_word(probabilities, int_to_vocab):
    """
    Pick the next word in the generated text
    :param probabilities: Probabilites of the next word
    :param int_to_vocab: Dictionary of word ids as the keys and words as the values
    :return: String of the predicted word
    """
    
    word = int_to_vocab[np.argmax(probabilities)]
    
    return word


"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_pick_word(pick_word)
Tests Passed
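
Always taking the argmax is deterministic and tends to get stuck repeating the same high-probability words. A hedged alternative sketch that samples the next word from the distribution instead (not the implementation tested above):

# Alternative sketch: sample the next word id according to its probability,
# which adds variety to the generated script.
def pick_word_sampled(probabilities, int_to_vocab):
    word_id = np.random.choice(len(probabilities), p=probabilities)
    return int_to_vocab[word_id]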

Generate TV Script

This will generate the TV script for you. Set gen_length to the length of TV script you want to generate.

In [275]:
gen_length = 200
# homer_simpson, moe_szyslak, or Barney_Gumble
prime_word = 'moe_szyslak'

"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Load saved model
    loader = tf.train.import_meta_graph(load_dir + '.meta')
    loader.restore(sess, load_dir)

    # Get Tensors from loaded model
    input_text, initial_state, final_state, probs = get_tensors(loaded_graph)

    # Sentences generation setup
    gen_sentences = [prime_word + ':']
    prev_state = sess.run(initial_state, {input_text: np.array([[1]])})

    # Generate sentences
    for n in range(gen_length):
        # Dynamic Input
        dyn_input = [[vocab_to_int[word] for word in gen_sentences[-seq_length:]]]
        dyn_seq_length = len(dyn_input[0])

        # Get Prediction
        probabilities, prev_state = sess.run(
            [probs, final_state],
            {input_text: dyn_input, initial_state: prev_state})
        
        pred_word = pick_word(probabilities[dyn_seq_length-1], int_to_vocab)

        gen_sentences.append(pred_word)
    
    # Remove tokens
    tv_script = ' '.join(gen_sentences)
    for key, token in token_dict.items():
        ending = ' ' if key in ['\n', '(', '"'] else ''
        tv_script = tv_script.replace(' ' + token.lower(), key)
    tv_script = tv_script.replace('\n ', '\n')
    tv_script = tv_script.replace('( ', '(')
        
    print(tv_script)
moe_szyslak: sizes good-looking slap detective_homer_simpson: takin' cesss planning parrot smoke parrot sizes frustrated choked slap gesture elmo's jerry duff's butterball officials sizes themselves gesture whiny irrelevant paintings continuing huddle tony butterball worst jerry neighborhood slap slap slap detective_homer_simpson: meatpies crooks sail slap slap slap sizes worst mr slap worst gesture parrot calendars bathed schnapps butterball stuck jerry dash my-y-y-y-y-y slap slap slap detective_homer_simpson: rain gesture bashir's jerry longest slap slap slap detective_homer_simpson: realize gesture parrot neighborhood jerry dad's poet presided scrutinizes presided rope neighborhood booth detective_homer_simpson: enjoyed gesture electronic sam: jerry dash my-y-y-y-y-y butterball protestantism dash my-y-y-y-y-y friendly dash happiness agreement slap protestantism muttering muttering sugar-free parrot is: abandon fudd scrutinizes detective_homer_simpson: itself duff's butterball drinker slap muttering shaky slap cuff giant face knockin' tv-station_announcer: that's slap detective_homer_simpson: celebrate rubbed 2nd_voice_on_transmitter: further rubbed usual laramie bunch slap detective_homer_simpson: itself gesture child jerry premise poet sarcastic slap detective_homer_simpson: meatpies skydiving scrutinizes scream renee: scrutinizes detective_homer_simpson: itself lenses butterball tapered smokin' 2nd_voice_on_transmitter: slap detective_homer_simpson: detective_homer_simpson: detective_homer_simpson: aims always butterball oh-so-sophisticated wine dislike sizes bury gang butterball renee: rope laramie themselves beings slap detective_homer_simpson: rain indicates butterball stunned slap detective_homer_simpson: rain arts butterball ratted 2nd_voice_on_transmitter: pepsi oh-so-sophisticated planning booth rope presided rope abandon worst

The TV Script is Nonsensical

It's ok if the TV script doesn't make any sense. We trained on less than a megabyte of text. In order to get good results, you'll have to use a smaller vocabulary or get more data. Luckily there's more data! As we mentioned at the beginning of this project, this is a subset of another dataset. We didn't have you train on all the data, because that would take too long. However, you are free to train your neural network on all the data. After you complete the project, of course.

Submitting This Project

When submitting this project, make sure to run all the cells before saving the notebook. Save the notebook file as "dlnd_tv_script_generation.ipynb" and save it as an HTML file under "File" -> "Download as". Include the "helper.py" and "problem_unittests.py" files in your submission.