
Step by step recurrent neural network inference with Keras

We often use recurrent neural network architectures for use cases such as streaming, where we make a new prediction each time new information arrives. A common approach is to feed the model the full context on every call, which repeats the same work over and over and is therefore computationally wasteful. In this blog post, I will illustrate several ways to perform inference step by step with recurrent architectures by carrying the states forward, removing the need to recompute the same steps on every prediction. My post Creating a custom training loop in tensorflow gives a quick start on the Keras functional API, which I use for the main code snippets.

Context

Recurrent Neural Networks (RNNs) are a popular type of neural network that specialises in time series data. Just as our brains can interpret a series of images as a movie or a sequence of words as a sentence, a recurrent neural network remembers past inputs, which influence its output for future inputs. It does this by maintaining a hidden state that is fed forward to the next iteration in a recurrent fashion. The unit that takes the previous state and the current input and returns the output together with the new state is called a cell.
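As a minimal sketch of the cell idea (this snippet is only an illustration and is not part of the models built below), a single step of tf.keras.layers.SimpleRNNCell takes the current input and the previous state and returns the output together with the new state:

import tensorflow as tf

# One step of a basic RNN cell: (input, previous state) -> (output, new state)
cell = tf.keras.layers.SimpleRNNCell(units=8)

x_t = tf.random.uniform((1, 4))      # input at time step t: batch of 1, 4 features
state = [tf.zeros((1, 8))]           # previous hidden state, zeros at t=0

output, state = cell(x_t, state)     # the returned state is fed into the next step
print(output.shape, state[0].shape)  # (1, 8) (1, 8)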

Building on this basic principle, architectures such as the LSTM (Long Short-Term Memory) improve on the vanilla RNN cell by adding a separate cell state and special gates that “remember” and “forget” information to and from the cell state. This greatly improves the ability of these networks to remember long sequences. This post focuses on the LSTM cell because of the extra complexity of the second state, but everything can be easily adapted to the other main recurrent architecture in use, the GRU cell.
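To make the difference concrete, here is another small sketch (again just an illustration, not part of the models below) of the state each cell type carries: the LSTM cell keeps two tensors, the hidden state h and the cell state c, while the GRU cell keeps a single hidden state. This is why adapting the code later in the post to GRUs means tracking one state tensor per layer instead of two.

import tensorflow as tf

x_t = tf.random.uniform((1, 4))

# The LSTM cell carries two state tensors: [h, c]
lstm_cell = tf.keras.layers.LSTMCell(8)
_, lstm_state = lstm_cell(x_t, [tf.zeros((1, 8)), tf.zeros((1, 8))])
print(len(lstm_state))  # 2

# The GRU cell carries a single state tensor: [h]
gru_cell = tf.keras.layers.GRUCell(8)
_, gru_state = gru_cell(x_t, [tf.zeros((1, 8))])
print(len(gru_state))   # 1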

Basic architectures of RNN, GRU and LSTM cells, by Saurabh Rathor on Medium

The wasteful approach

The wasteful approach provides the entire context for every prediction. To predict time step i, the model receives inputs [0..i] as context. Since the hidden state after time step k (0 <= k < i) is deterministic given the same inputs and weights, recomputing it for each new prediction is unnecessary.

First, we need to define imports, fix the seed, set the parameters for the models and generate a random input that we will reuse with all the models:

# Imports
import tensorflow as tf
import tensorflow.keras.layers as L
import matplotlib.pyplot as plt
from time import perf_counter

# Fix the seed so it always generates the same X and weights
tf.random.set_seed(0)

# Define parameters
batch_size = 10
seq_length = 200
features = 4
lstm_1_units = 100
lstm_2_units = 200

# Generate a random input for all models
X = tf.random.uniform((batch_size, seq_length, features))

The “classic” way: a standard model, built exactly as you normally would.

input1 = L.Input((None, features))

lstm1_1 = L.LSTM(lstm_1_units, return_sequences=True)
lstm1_1_o = lstm1_1(input1)

lstm2_1 = L.LSTM(lstm_2_units)
lstm2_1_o = lstm2_1(lstm1_1_o)

output_dense_1 = L.Dense(1, activation="sigmoid")
output1 = output_dense_1(lstm2_1_o)

model1 = tf.keras.Model(inputs=input1, outputs=output1)

start1 = perf_counter()
classic_results = []
for i in range(1, len(X[0])):
    # Feed the entire context [0..i) for every prediction
    result_batch = model1.predict(X[:, 0:i, :])
    classic_results.append(result_batch[0][0])
stop1 = perf_counter()

plt.plot(classic_results)
plt.title(f"Classic approach took {(stop1 - start1):.03f}s")
plt.show()
Classic approach output

To improve performance, a common compromise is to restrict the context to a window [i-B..i]. This is significantly cheaper computationally, but it still repeats most of the operations and prevents the model from using any information outside the window.
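For illustration, a windowed version of the loop above could look like this (B is an arbitrary window size I picked for the sketch; the windowed results are not used in the rest of the post):

B = 32
windowed_results = []
for i in range(1, len(X[0])):
    # Only the last B time steps are provided as context
    result_batch = model1.predict(X[:, max(0, i - B) : i, :])
    windowed_results.append(result_batch[0][0])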

The stateful model

Keras offers “stateful=True” for LSTM layers, which makes them maintain their state between calls automatically.
This works well when the items in the batch always keep the same order and identity, e.g.,

  • you are streaming live data and are always appending to the same sequences
  • you want to evaluate a long sequence without keeping the entire thing in memory

A stateful layer needs to know its batch size, so it is specified on the Input layer:

input2 = L.Input((None, 4), batch_size=batch_size)

lstm1_2 = L.LSTM(lstm_1_units, return_sequences=True, stateful=True)
lstm1_2_o = lstm1_2(input2)

lstm2_2 = L.LSTM(lstm_2_units, stateful=True)
lstm2_2_o = lstm2_2(lstm1_2_o)

output_dense_2 = L.Dense(1, activation="sigmoid")
output2 = output_dense_2(lstm2_2_o)

model2 = tf.keras.Model(inputs=input2, outputs=output2)

# Copy the weights from the first model
lstm1_2.set_weights(lstm1_1.get_weights())
lstm2_2.set_weights(lstm2_1.get_weights())
output_dense_2.set_weights(output_dense_1.get_weights())

start2 = perf_counter()

stateful_results = []
for i in range(len(X[0]) - 1):
    # Because the layers are stateful, we can feed the time steps one by one:
    # whole batch, the single time step i, all features
    result_batch = model2.predict(X[:, i : i + 1, :])
    stateful_results.append(result_batch[0][0])

stop2 = perf_counter()

plt.plot(stateful_results)
plt.title(f"Stateful model took {(stop2 - start2):.03f}s")
plt.show()
Step by step - stateful approach output
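One consequence of the stateful approach is that when the stream restarts, or a new, unrelated sequence begins, the accumulated state has to be cleared explicitly. A minimal sketch (not part of the timing comparison above):

# Clear the accumulated state before feeding an unrelated sequence.
# Note that this resets the state for the entire batch at once.
lstm1_2.reset_states()
lstm2_2.reset_states()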

Step by step inference using the functional API

Sometimes the items in the batch change between calls, or are not all present at every step. Manually managing the state is useful here because it lets some items in the batch continue processing while others are reset or replaced (a sketch of a per-item reset follows the code below). The stateful approach only offers reset_states(), which clears the state of the entire batch at once.

Using the functional API, we can build a model that accepts the states as additional inputs: it takes the c and h for lstm1 and lstm2, and returns the new c and h for both layers alongside the prediction. The states can then be managed in the prediction loop as required, or saved for later.

input3 = L.Input((None, features))
lstm1_3_c_input = L.Input((lstm_1_units,))
lstm1_3_h_input = L.Input((lstm_1_units,))
lstm2_3_c_input = L.Input((lstm_2_units,))
lstm2_3_h_input = L.Input((lstm_2_units,))


# return_state=True makes each LSTM also return its final h and c states
lstm1_3 = L.LSTM(lstm_1_units, return_sequences=True, return_state=True)
lstm1_3_o, lstm1_3_h, lstm1_3_c = lstm1_3(
    input3, initial_state=[lstm1_3_h_input, lstm1_3_c_input]
)
lstm2_3 = L.LSTM(lstm_2_units, return_state=True)
lstm2_3_o, lstm2_3_h, lstm2_3_c = lstm2_3(
    lstm1_3_o, initial_state=[lstm2_3_h_input, lstm2_3_c_input]
)

output_dense_3 = L.Dense(1, activation="sigmoid")
output3 = output_dense_3(lstm2_3_o)

model3 = tf.keras.Model(
    inputs={
        "input": input3,
        "lstm1_h": lstm1_3_h_input,
        "lstm1_c": lstm1_3_c_input,
        "lstm2_h": lstm2_3_h_input,
        "lstm2_c": lstm2_3_c_input,
    },
    outputs={
        "output": output3,
        "lstm1_h": lstm1_3_h,
        "lstm1_c": lstm1_3_c,
        "lstm2_h": lstm2_3_h,
        "lstm2_c": lstm2_3_c,
    },
)

lstm1_3.set_weights(lstm1_1.get_weights())
lstm2_3.set_weights(lstm2_1.get_weights())
output_dense_3.set_weights(output_dense_1.get_weights())


# Initialise the hidden and cell states to zeros (the Keras default)
lstm1_3_h_val = tf.zeros((batch_size, lstm_1_units))
lstm1_3_c_val = tf.zeros((batch_size, lstm_1_units))
lstm2_3_h_val = tf.zeros((batch_size, lstm_2_units))
lstm2_3_c_val = tf.zeros((batch_size, lstm_2_units))

start3 = perf_counter()
manual_func_results = []
for i in range(len(X[0]) - 1):
    ts_res = model3.predict(
        {
            "input": X[:, i : i + 1, :],
            "lstm1_h": lstm1_3_h_val,
            "lstm1_c": lstm1_3_c_val,
            "lstm2_h": lstm2_3_h_val,
            "lstm2_c": lstm2_3_c_val,
        }
    )
    manual_func_results.append(ts_res["output"][0][0])
    lstm1_3_h_val = ts_res["lstm1_h"]
    lstm1_3_c_val = ts_res["lstm1_c"]
    lstm2_3_h_val = ts_res["lstm2_h"]
    lstm2_3_c_val = ts_res["lstm2_c"]
stop3 = perf_counter()

plt.plot(manual_func_results)
plt.title(f"Manual functional model took {(stop3 - start3):.03f}s")
plt.show()
Step by step - functional approach output
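As an illustration of the flexibility mentioned above (this is my own sketch, not part of the original comparison): when one sequence in the batch ends or restarts, only its rows in the state tensors need to be zeroed, while the other items keep their accumulated state.

j = 0  # hypothetical index of the batch item whose sequence restarted
keep = tf.reshape(tf.cast(tf.range(batch_size) != j, tf.float32), (batch_size, 1))

# Zero the states of item j only; every other row is left untouched
lstm1_3_h_val = lstm1_3_h_val * keep
lstm1_3_c_val = lstm1_3_c_val * keep
lstm2_3_h_val = lstm2_3_h_val * keep
lstm2_3_c_val = lstm2_3_c_val * keep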

Step by step inference using a subclassed model

The same manual state management can also be expressed with a subclassed tf.keras.Model: call receives the previous states as part of the input dictionary and returns the updated states alongside the prediction.

class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.lstm1 = L.LSTM(lstm_1_units, return_sequences=True, return_state=True)
        self.lstm2 = L.LSTM(lstm_2_units, return_state=True)
        self.output_dense = L.Dense(1, activation="sigmoid")

    def call(self, inputs):
        lstm1_o, lstm1_h, lstm1_c = self.lstm1(
            inputs["input"], initial_state=[inputs["lstm1_h"], inputs["lstm1_c"]]
        )
        lstm2_o, lstm2_h, lstm2_c = self.lstm2(
            lstm1_o, initial_state=[inputs["lstm2_h"], inputs["lstm2_c"]]
        )
        output_dense = self.output_dense(lstm2_o)
        return {
            "output": output_dense,
            "lstm1_h": lstm1_h,
            "lstm1_c": lstm1_c,
            "lstm2_h": lstm2_h,
            "lstm2_c": lstm2_c,
        }


model4 = SubclassedModel()
model4.build(
    input_shape={
        "input": (batch_size, None, features),
        "lstm1_h": (batch_size, lstm_1_units),
        "lstm1_c": (batch_size, lstm_1_units),
        "lstm2_h": (batch_size, lstm_2_units),
        "lstm2_c": (batch_size, lstm_2_units),
    }
)

model4.lstm1.set_weights(lstm1_1.get_weights())
model4.lstm2.set_weights(lstm2_1.get_weights())
model4.output_dense.set_weights(output_dense_1.get_weights())


lstm1_4_h_val = tf.zeros((batch_size, lstm_1_units))
lstm1_4_c_val = tf.zeros((batch_size, lstm_1_units))
lstm2_4_h_val = tf.zeros((batch_size, lstm_2_units))
lstm2_4_c_val = tf.zeros((batch_size, lstm_2_units))

state = {
    "lstm1_h": lstm1_4_h_val,
    "lstm1_c": lstm1_4_c_val,
    "lstm2_h": lstm2_4_h_val,
    "lstm2_c": lstm2_4_c_val,
}

start4 = perf_counter()
manual_subcl_results = []
for i in range(len(X[0]) - 1):
    state["input"] = X[:, i : i + 1, :]
    state = model4.predict(state)
    manual_subcl_results.append(state["output"][0][0])
stop4 = perf_counter()

plt.plot(manual_subcl_results)
plt.title(f"Subclassed model took {(stop4 - start4):.03f}")
plt.show()
Step by step - subclassed approach output
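Because the state is just a dictionary of arrays, it can also be persisted between calls, which is what makes this pattern usable from a stateless service. A minimal sketch using numpy (the file name is hypothetical):

import numpy as np

# Persist the current states so a later process (or another container) can resume
np.savez("lstm_state.npz", **{k: v for k, v in state.items() if k != "output"})

# ...later, possibly in a different process: restore the states and continue
loaded = np.load("lstm_state.npz")
restored = {k: loaded[k] for k in loaded.files}
restored["input"] = X[:, -1:, :]  # the next, not yet processed time step
next_step = model4.predict(restored)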

Ensuring all outputs were identical

assert tf.reduce_all(tf.equal(classic_results, stateful_results))
assert tf.reduce_all(tf.equal(stateful_results, manual_func_results))
assert tf.reduce_all(tf.equal(manual_func_results, manual_subcl_results))

Conclusion

Recurrent architectures are known for being slow, so it is worth noting the speedup obtained even with such a small model. Knowing how to manage recurrent hidden states manually improves the performance of machine learning models used for streaming or deployed in stateless containers.

Jupyter notebook here
