.. _sec_gru:
Gated Recurrent Units (GRU)
===========================
As RNNs, and particularly the LSTM architecture (:numref:`sec_lstm`),
rapidly gained popularity during the 2010s, a number of researchers
began to experiment with simplified architectures. They hoped to retain
the key idea of an internal state with multiplicative gating mechanisms
while speeding up computation. The gated
recurrent unit (GRU) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014`
offered a streamlined version of the LSTM memory cell that often
achieves comparable performance while being faster
to compute :cite:`Chung.Gulcehre.Cho.ea.2014`.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
Reset Gate and Update Gate
--------------------------
Here, the LSTM’s three gates are replaced by two: the *reset gate* and
the *update gate*. As with LSTMs, these gates are given sigmoid
activations, forcing their values to lie in the interval :math:`(0, 1)`.
Intuitively, the reset gate controls how much of the previous state we
might still want to remember. Likewise, an update gate would allow us to
control how much of the new state is just a copy of the old one.
:numref:`fig_gru_1` illustrates the inputs for both the reset and
update gates in a GRU, given the input of the current time step and the
hidden state of the previous time step. The outputs of the gates are
given by two fully connected layers with a sigmoid activation function.
.. _fig_gru_1:
.. figure:: ../img/gru-1.svg
Computing the reset gate and the update gate in a GRU model.
Mathematically, for a given time step :math:`t`, suppose that the input
is a minibatch :math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (number
of examples :math:`=n`; number of inputs :math:`=d`) and the hidden
state of the previous time step is
:math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}` (number of hidden
units :math:`=h`). Then the reset gate
:math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` and update gate
:math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` are computed as
follows:
.. math::
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xr}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hr}} + \mathbf{b}_\textrm{r}),\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xz}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hz}} + \mathbf{b}_\textrm{z}),
\end{aligned}
where
:math:`\mathbf{W}_{\textrm{xr}}, \mathbf{W}_{\textrm{xz}} \in \mathbb{R}^{d \times h}`
and
:math:`\mathbf{W}_{\textrm{hr}}, \mathbf{W}_{\textrm{hz}} \in \mathbb{R}^{h \times h}`
are weight parameters and
:math:`\mathbf{b}_\textrm{r}, \mathbf{b}_\textrm{z} \in \mathbb{R}^{1 \times h}`
are bias parameters.
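The two gate equations can be checked numerically with a small sketch. It uses NumPy rather than any of the frameworks in this section, and the sizes ``n``, ``d``, ``h`` as well as the random parameter values are arbitrary assumptions chosen for illustration.

```python
import numpy as np

def sigmoid(x):
    """Logistic function, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n, d, h = 4, 5, 3  # batch size, number of inputs, number of hidden units

X_t = rng.normal(size=(n, d))     # input at the current time step
H_prev = rng.normal(size=(n, h))  # hidden state of the previous time step

# One (W_x, W_h, b) triple per gate, with the shapes given in the text
W_xr, W_hr, b_r = rng.normal(size=(d, h)), rng.normal(size=(h, h)), np.zeros((1, h))
W_xz, W_hz, b_z = rng.normal(size=(d, h)), rng.normal(size=(h, h)), np.zeros((1, h))

R_t = sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)  # reset gate
Z_t = sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)  # update gate

print(R_t.shape, Z_t.shape)
print(R_t.min(), R_t.max())
```

Both gates have the same shape as the hidden state, and the bias row of shape :math:`1 \times h` is broadcast over the minibatch.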
Candidate Hidden State
----------------------
Next, we integrate the reset gate :math:`\mathbf{R}_t` with the regular
updating mechanism in :eq:`rnn_h_with_state`, leading to the
following *candidate hidden state*
:math:`\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}` at time step
:math:`t`:
.. math:: \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h}),
:label: gru_tilde_H
where :math:`\mathbf{W}_{\textrm{xh}} \in \mathbb{R}^{d \times h}` and
:math:`\mathbf{W}_{\textrm{hh}} \in \mathbb{R}^{h \times h}` are weight
parameters, :math:`\mathbf{b}_\textrm{h} \in \mathbb{R}^{1 \times h}` is
the bias, and the symbol :math:`\odot` is the Hadamard (elementwise)
product operator. Here we use a tanh activation function.
The result is a *candidate*, since we still need to incorporate the
action of the update gate. Compared with :eq:`rnn_h_with_state`,
the influence of the previous states can now be reduced through the
elementwise multiplication of :math:`\mathbf{R}_t` and
:math:`\mathbf{H}_{t-1}` in :eq:`gru_tilde_H`. Whenever the entries
in the reset gate :math:`\mathbf{R}_t` are close to 1, the candidate
coincides with the update of a vanilla RNN such as that in
:eq:`rnn_h_with_state`. For all entries of the reset gate
:math:`\mathbf{R}_t` that are close to 0, the candidate hidden state is
the result of an MLP with :math:`\mathbf{X}_t` as input. Any
pre-existing hidden state is thus *reset* to defaults.
:numref:`fig_gru_2` illustrates the computational flow after applying
the reset gate.
.. _fig_gru_2:
.. figure:: ../img/gru-2.svg
Computing the candidate hidden state in a GRU model.
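The two limiting cases of the reset gate can be verified directly. The following NumPy sketch uses arbitrary sizes and random values (assumptions for illustration only) and fixes the reset gate by hand instead of computing it from the inputs.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, h = 2, 4, 3
X_t = rng.normal(size=(n, d))
H_prev = rng.normal(size=(n, h))
W_xh = rng.normal(size=(d, h))
W_hh = rng.normal(size=(h, h))
b_h = np.zeros((1, h))

def candidate(R_t):
    """Candidate hidden state for a given reset gate value."""
    return np.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

# Reset gate fully open: identical to a vanilla RNN update
rnn_update = np.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
H_tilde_open = candidate(np.ones((n, h)))

# Reset gate fully closed: the previous state drops out entirely
input_only = np.tanh(X_t @ W_xh + b_h)
H_tilde_closed = candidate(np.zeros((n, h)))

print(np.allclose(H_tilde_open, rnn_update))
print(np.allclose(H_tilde_closed, input_only))
```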
Hidden State
------------
Finally, we need to incorporate the effect of the update gate
:math:`\mathbf{Z}_t`. This gate determines the extent to which the new
hidden state :math:`\mathbf{H}_t \in \mathbb{R}^{n \times h}` matches the
old state :math:`\mathbf{H}_{t-1}` as opposed to the new candidate state
:math:`\tilde{\mathbf{H}}_t`. It does so by taking
elementwise convex combinations of :math:`\mathbf{H}_{t-1}` and
:math:`\tilde{\mathbf{H}}_t`, which leads to the final update equation
for the GRU:
.. math:: \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
Whenever the update gate :math:`\mathbf{Z}_t` is close to 1, we simply
retain the old state. In this case the information from
:math:`\mathbf{X}_t` is ignored, effectively skipping time step
:math:`t` in the dependency chain. By contrast, whenever
:math:`\mathbf{Z}_t` is close to 0, the new latent state
:math:`\mathbf{H}_t` approaches the candidate latent state
:math:`\tilde{\mathbf{H}}_t`. :numref:`fig_gru_3` shows the
computational flow after the update gate is in action.
.. _fig_gru_3:
.. figure:: ../img/gru-3.svg
Computing the hidden state in a GRU model.
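Because the update is an elementwise convex combination, each hidden unit can independently keep its old value, take the candidate value, or blend the two. A minimal sketch with hand-picked numbers (all values are illustrative assumptions):

```python
import numpy as np

H_prev = np.array([[1.0, -1.0, 0.5]])   # old hidden state
H_tilde = np.array([[0.0, 0.8, -0.5]])  # candidate hidden state

def gru_update(Z_t):
    """Final GRU update: convex combination of old and candidate state."""
    return Z_t * H_prev + (1 - Z_t) * H_tilde

keep = gru_update(np.ones_like(H_prev))      # Z = 1: old state is retained
replace = gru_update(np.zeros_like(H_prev))  # Z = 0: candidate is adopted
# First unit copied, second unit replaced, third unit averaged
mixed = gru_update(np.array([[1.0, 0.0, 0.5]]))

print(keep, replace, mixed)
```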
In summary, GRUs have the following two distinguishing features:
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
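To see why update gates matter for long-term dependencies, consider a single hidden unit whose update gate is held at a constant value while the candidate state is 0. After :math:`T` steps, a fraction :math:`z^T` of the initial state survives. The sketch below is a simplified scalar thought experiment, not part of the model code.

```python
def carried_fraction(z, num_steps):
    """Fraction of the initial state left after num_steps updates when
    the update gate is held at z and the candidate state is always 0."""
    h = 1.0
    for _ in range(num_steps):
        h = z * h + (1 - z) * 0.0
    return h

for z in (0.5, 0.9, 0.99):
    print(z, carried_fraction(z, 100))
```

With a gate value of 0.99, roughly 37% of the initial state is still present after 100 steps, whereas with 0.5 it has vanished for all practical purposes.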
Implementation from Scratch
---------------------------
To gain a better understanding of the GRU model, let’s implement it from
scratch.
Initializing Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The first step is to initialize the model parameters. We draw the
weights from a Gaussian distribution with standard deviation
``sigma`` and set the biases to 0. The hyperparameter ``num_hiddens``
defines the number of hidden units. We instantiate all weights and
biases relating to the update gate, the reset gate, and the candidate
hidden state.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRUScratch(d2l.Module):
       def __init__(self, num_inputs, num_hiddens, sigma=0.01):
           super().__init__()
           self.save_hyperparameters()

           init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
           triple = lambda: (init_weight(num_inputs, num_hiddens),
                             init_weight(num_hiddens, num_hiddens),
                             nn.Parameter(torch.zeros(num_hiddens)))
           self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
           self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
           self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRUScratch(d2l.Module):
       def __init__(self, num_inputs, num_hiddens, sigma=0.01):
           super().__init__()
           self.save_hyperparameters()

           init_weight = lambda *shape: np.random.randn(*shape) * sigma
           triple = lambda: (init_weight(num_inputs, num_hiddens),
                             init_weight(num_hiddens, num_hiddens),
                             np.zeros(num_hiddens))
           self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
           self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
           self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRUScratch(d2l.Module):
       num_inputs: int
       num_hiddens: int
       sigma: float = 0.01

       def setup(self):
           init_weight = lambda name, shape: self.param(name,
                                                        nn.initializers.normal(self.sigma),
                                                        shape)
           # Shapes are passed as tuples; note the trailing comma for the bias
           triple = lambda name: (
               init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
               init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
               self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens,)))

           self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
           self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
           self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRUScratch(d2l.Module):
       def __init__(self, num_inputs, num_hiddens, sigma=0.01):
           super().__init__()
           self.save_hyperparameters()

           init_weight = lambda *shape: tf.Variable(tf.random.normal(shape) * sigma)
           triple = lambda: (init_weight(num_inputs, num_hiddens),
                             init_weight(num_hiddens, num_hiddens),
                             tf.Variable(tf.zeros(num_hiddens)))
           self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
           self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
           self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
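Scaling standard normal samples by ``sigma`` yields weights with mean 0 and standard deviation ``sigma``. The following NumPy sketch checks this empirically; the matrix size of 256 is an arbitrary assumption, chosen only so that the sample statistics are stable.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.01

# Standard normal samples scaled by sigma, as in the initializers above
W = rng.normal(size=(256, 256)) * sigma
b = np.zeros(256)  # biases start at exactly zero

print(W.mean(), W.std())
```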
Defining the Model
~~~~~~~~~~~~~~~~~~
Now we are ready to define the GRU forward computation. Its structure is
the same as that of the basic RNN cell, except that the update equations
are more complex.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   @d2l.add_to_class(GRUScratch)
   def forward(self, inputs, H=None):
       if H is None:
           # Initial state with shape: (batch_size, num_hiddens)
           H = torch.zeros((inputs.shape[1], self.num_hiddens),
                           device=inputs.device)
       outputs = []
       for X in inputs:
           Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                             torch.matmul(H, self.W_hz) + self.b_z)
           R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                             torch.matmul(H, self.W_hr) + self.b_r)
           H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
                                torch.matmul(R * H, self.W_hh) + self.b_h)
           H = Z * H + (1 - Z) * H_tilde
           outputs.append(H)
       return outputs, H
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   @d2l.add_to_class(GRUScratch)
   def forward(self, inputs, H=None):
       if H is None:
           # Initial state with shape: (batch_size, num_hiddens)
           H = np.zeros((inputs.shape[1], self.num_hiddens),
                        ctx=inputs.ctx)
       outputs = []
       for X in inputs:
           Z = npx.sigmoid(np.dot(X, self.W_xz) +
                           np.dot(H, self.W_hz) + self.b_z)
           R = npx.sigmoid(np.dot(X, self.W_xr) +
                           np.dot(H, self.W_hr) + self.b_r)
           H_tilde = np.tanh(np.dot(X, self.W_xh) +
                             np.dot(R * H, self.W_hh) + self.b_h)
           H = Z * H + (1 - Z) * H_tilde
           outputs.append(H)
       return outputs, H
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   @d2l.add_to_class(GRUScratch)
   def forward(self, inputs, H=None):
       # Use lax.scan primitive instead of looping over the
       # inputs, since scan saves time in jit compilation
       def scan_fn(H, X):
           Z = jax.nn.sigmoid(jnp.matmul(X, self.W_xz) +
                              jnp.matmul(H, self.W_hz) + self.b_z)
           R = jax.nn.sigmoid(jnp.matmul(X, self.W_xr) +
                              jnp.matmul(H, self.W_hr) + self.b_r)
           H_tilde = jnp.tanh(jnp.matmul(X, self.W_xh) +
                              jnp.matmul(R * H, self.W_hh) + self.b_h)
           H = Z * H + (1 - Z) * H_tilde
           return H, H  # return carry, y

       if H is None:
           batch_size = inputs.shape[1]
           carry = jnp.zeros((batch_size, self.num_hiddens))
       else:
           carry = H

       # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
       carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
       return outputs, carry
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   @d2l.add_to_class(GRUScratch)
   def forward(self, inputs, H=None):
       if H is None:
           # Initial state with shape: (batch_size, num_hiddens)
           H = tf.zeros((inputs.shape[1], self.num_hiddens))
       outputs = []
       for X in inputs:
           Z = tf.sigmoid(tf.matmul(X, self.W_xz) +
                          tf.matmul(H, self.W_hz) + self.b_z)
           R = tf.sigmoid(tf.matmul(X, self.W_xr) +
                          tf.matmul(H, self.W_hr) + self.b_r)
           H_tilde = tf.tanh(tf.matmul(X, self.W_xh) +
                             tf.matmul(R * H, self.W_hh) + self.b_h)
           H = Z * H + (1 - Z) * H_tilde
           outputs.append(H)
       return outputs, H
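A framework-neutral way to sanity-check the recurrence is to run it on random data and inspect the shapes. The sketch below restates the forward pass in plain NumPy; the helper names ``init_params`` and ``gru_forward`` and the small sizes are hypothetical and are not part of the ``d2l`` library.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(num_inputs, num_hiddens, sigma=0.01, seed=0):
    """One (W_x, W_h, b) triple each for update gate, reset gate, candidate."""
    rng = np.random.default_rng(seed)
    def triple():
        return (rng.normal(scale=sigma, size=(num_inputs, num_hiddens)),
                rng.normal(scale=sigma, size=(num_hiddens, num_hiddens)),
                np.zeros(num_hiddens))
    return {"z": triple(), "r": triple(), "h": triple()}

def gru_forward(inputs, params, H=None):
    """inputs has shape (num_steps, batch_size, num_inputs)."""
    W_xz, W_hz, b_z = params["z"]
    W_xr, W_hr, b_r = params["r"]
    W_xh, W_hh, b_h = params["h"]
    if H is None:
        # Initial state with shape (batch_size, num_hiddens)
        H = np.zeros((inputs.shape[1], W_hz.shape[0]))
    outputs = []
    for X in inputs:  # iterate over time steps
        Z = sigmoid(X @ W_xz + H @ W_hz + b_z)
        R = sigmoid(X @ W_xr + H @ W_hr + b_r)
        H_tilde = np.tanh(X @ W_xh + (R * H) @ W_hh + b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H

num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 7, 4
inputs = np.random.default_rng(1).normal(size=(num_steps, batch_size, num_inputs))
outputs, H = gru_forward(inputs, init_params(num_inputs, num_hiddens))
print(len(outputs), outputs[0].shape, H.shape)
```

There is one output per time step, each of shape (batch size, number of hidden units), and the returned state equals the last output.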
Training
~~~~~~~~
Training a language model on *The Time Machine* dataset works in exactly
the same manner as in :numref:`sec_rnn-scratch`.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_48_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_51_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_54_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   data = d2l.TimeMachine(batch_size=1024, num_steps=32)
   with d2l.try_gpu():
       gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
       model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
   trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
   trainer.fit(model, data)
.. figure:: output_gru_b77a34_57_0.svg
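The training calls above pass ``gradient_clip_val=1``. Assuming that this value is a threshold on the joint norm of all gradients, as in the gradient clipping used for RNNs trained from scratch, the operation can be sketched in NumPy as follows. The function name ``clip_by_global_norm`` is hypothetical and only illustrates the idea.

```python
import numpy as np

def clip_by_global_norm(grads, clip_val):
    """Rescale a list of gradient arrays so that their joint L2 norm
    is at most clip_val. Returns the new gradients and the old norm."""
    norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    if norm > clip_val:
        grads = [g * (clip_val / norm) for g in grads]
    return grads, norm

# Two made-up gradient arrays whose joint norm is 5
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm_before = clip_by_global_norm(grads, clip_val=1.0)
norm_after = np.sqrt(sum(float((g ** 2).sum()) for g in clipped))
print(norm_before, norm_after)
```

Clipping changes only the length of the overall gradient, not its direction.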
Concise Implementation
----------------------
In high-level APIs, we can directly instantiate a GRU model. This
encapsulates all the configuration details that we made explicit above.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRU(d2l.RNN):
       def __init__(self, num_inputs, num_hiddens):
           d2l.Module.__init__(self)
           self.save_hyperparameters()
           self.rnn = nn.GRU(num_inputs, num_hiddens)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRU(d2l.RNN):
       def __init__(self, num_inputs, num_hiddens):
           d2l.Module.__init__(self)
           self.save_hyperparameters()
           self.rnn = rnn.GRU(num_hiddens)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRU(d2l.RNN):
       num_hiddens: int

       @nn.compact
       def __call__(self, inputs, H=None, training=False):
           if H is None:
               batch_size = inputs.shape[1]
               H = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0),
                                               (batch_size,), self.num_hiddens)

           GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                         in_axes=0, out_axes=0, split_rngs={"params": False})

           H, outputs = GRU()(H, inputs)
           return outputs, H
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class GRU(d2l.RNN):
       def __init__(self, num_inputs, num_hiddens):
           d2l.Module.__init__(self)
           self.save_hyperparameters()
           self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
                                          return_state=True)
The code is significantly faster in training as it uses compiled
operators rather than Python.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_78_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_81_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
.. figure:: output_gru_b77a34_84_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
   with d2l.try_gpu():
       model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
   trainer.fit(model, data)
.. figure:: output_gru_b77a34_87_0.svg
After training, we print out the perplexity on the training set and the
predicted sequence following the provided prefix.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has so it and the time '
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has i have the time tra'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, trainer.state.params)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has is a move and a mov'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has t t t t t t t t t t'
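As a reminder of what the perplexity measures, it is the exponential of the average negative log-likelihood that the model assigns to the observed tokens. The sketch below computes it from a list of per-token probabilities; the function name and the vocabulary size of 28 are assumptions made for illustration.

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-probability of the observed tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

perfect = perplexity([1.0, 1.0, 1.0])  # model always certain and correct
uniform = perplexity([1 / 28] * 10)    # uniform guessing over 28 tokens
print(perfect, uniform)
```

A perfect model reaches a perplexity of 1, while uniform guessing yields a perplexity equal to the vocabulary size.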
Summary
-------
Compared with LSTMs, GRUs achieve similar performance but tend to be
lighter computationally. Generally, compared with simple RNNs, gated
RNNs such as LSTMs and GRUs can better capture dependencies for
sequences with large time step distances. GRUs contain basic RNNs as
an extreme case: when the entries of the reset gate are close to 1 and
those of the update gate are close to 0, the update reduces to that of
a basic RNN. They can also skip subsequences by turning on the update
gate, i.e., by keeping its entries close to 1.
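One way to make the difference in cost concrete is to count the parameters of the recurrent layer alone. A basic RNN uses one triple of input weights, hidden weights, and bias, the GRU described in this section uses three, and an LSTM with three gates plus an input node uses four. The sizes below (28 inputs, 32 hidden units) are illustrative assumptions, and the output layer is not counted.

```python
def recurrent_layer_params(num_triples, num_inputs, num_hiddens):
    """Parameters in num_triples sets of (W_x, W_h, b)."""
    per_triple = (num_inputs * num_hiddens      # W_x
                  + num_hiddens * num_hiddens   # W_h
                  + num_hiddens)                # b
    return num_triples * per_triple

d, h = 28, 32
counts = {name: recurrent_layer_params(k, d, h)
          for name, k in (("rnn", 1), ("gru", 3), ("lstm", 4))}
print(counts)
print(counts["gru"] / counts["lstm"])
```

Under these assumptions, the GRU layer has three quarters of the parameters of an LSTM layer of the same width.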
Exercises
---------
1. Assume that we only want to use the input at time step :math:`t'` to
predict the output at time step :math:`t > t'`. What are the best
values for the reset and update gates for each time step?
2. Adjust the hyperparameters and analyze their influence on running
time, perplexity, and the output sequence.
3. Compare runtime, perplexity, and the output strings for ``rnn.RNN``
and ``rnn.GRU`` implementations with each other.
4. What happens if you implement only parts of a GRU, e.g., with only a
reset gate or only an update gate?