.. _sec_natural-language-inference-attention:
Natural Language Inference: Using Attention
===========================================
We introduced the natural language inference task and the SNLI dataset
in :numref:`sec_natural-language-inference-and-dataset`. Whereas many
models for this task rely on complex and deep architectures,
:cite:t:`Parikh.Tackstrom.Das.ea.2016` proposed to address natural
language inference with attention mechanisms and called the result a
“decomposable attention model”. The model has no recurrent or
convolutional layers, yet it achieved the best result at the time
on the SNLI dataset with far fewer parameters. In this section, we will
describe and implement this attention-based method (with MLPs) for
natural language inference, as depicted in
:numref:`fig_nlp-map-nli-attention`.
.. _fig_nlp-map-nli-attention:
.. figure:: ../img/nlp-map-nli-attention.svg
This section feeds pretrained GloVe to an architecture based on
attention and MLPs for natural language inference.
The Model
---------
Rather than preserving the order of tokens in premises and hypotheses,
we can simply align tokens in one text sequence to every token in the
other, and vice versa, and then compare and aggregate such information to
predict the logical relationships between premises and hypotheses.
Similar to alignment of tokens between source and target sentences in
machine translation, the alignment of tokens between premises and
hypotheses can be neatly accomplished by attention mechanisms.
.. _fig_nli_attention:
.. figure:: ../img/nli-attention.svg
Natural language inference using attention mechanisms.
:numref:`fig_nli_attention` depicts the natural language inference
method using attention mechanisms. At a high level, it consists of three
jointly trained steps: attending, comparing, and aggregating. We will
illustrate them step by step in the following.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
Attending
~~~~~~~~~
The first step is to align tokens in one text sequence to each token in
the other sequence. Suppose that the premise is “i do need sleep” and
the hypothesis is “i am tired”. Due to semantic similarity, we may
wish to align “i” in the hypothesis with “i” in the premise, and align
“tired” in the hypothesis with “sleep” in the premise. Likewise, we may
wish to align “i” in the premise with “i” in the hypothesis, and align
“need” and “sleep” in the premise with “tired” in the hypothesis. Note
that such alignment is *soft*: it uses a weighted average, where ideally
large weights are associated with the tokens to be aligned. For ease of
demonstration, :numref:`fig_nli_attention` shows such alignment in a
*hard* way.
Now we describe the soft alignment using attention mechanisms in more
detail. Denote by
:math:`\mathbf{A} = (\mathbf{a}_1, \ldots, \mathbf{a}_m)` and
:math:`\mathbf{B} = (\mathbf{b}_1, \ldots, \mathbf{b}_n)` the premise
and hypothesis, whose numbers of tokens are :math:`m` and :math:`n`,
respectively, where
:math:`\mathbf{a}_i, \mathbf{b}_j \in \mathbb{R}^{d}`
(:math:`i = 1, \ldots, m, j = 1, \ldots, n`) are :math:`d`-dimensional
word vectors. For soft alignment, we compute the attention weights
:math:`e_{ij} \in \mathbb{R}` as
.. math:: e_{ij} = f(\mathbf{a}_i)^\top f(\mathbf{b}_j),
:label: eq_nli_e
where the function :math:`f` is an MLP defined in the following ``mlp``
function. The output dimension of :math:`f` is specified by the
``num_hiddens`` argument of ``mlp``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def mlp(num_inputs, num_hiddens, flatten):
        net = []
        net.append(nn.Dropout(0.2))
        net.append(nn.Linear(num_inputs, num_hiddens))
        net.append(nn.ReLU())
        if flatten:
            net.append(nn.Flatten(start_dim=1))
        net.append(nn.Dropout(0.2))
        net.append(nn.Linear(num_hiddens, num_hiddens))
        net.append(nn.ReLU())
        if flatten:
            net.append(nn.Flatten(start_dim=1))
        return nn.Sequential(*net)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def mlp(num_hiddens, flatten):
        net = nn.Sequential()
        net.add(nn.Dropout(0.2))
        net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
        net.add(nn.Dropout(0.2))
        net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
        return net
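The following sketch is not part of the original implementation. It
illustrates two PyTorch behaviors that the ``mlp`` function relies on:
``nn.Linear`` transforms only the last axis, so it can be applied to
every token of a 3-D input at once, and ``nn.Flatten(start_dim=1)``
merges all axes after the first. The tensor sizes are arbitrary values
chosen for illustration.

.. code:: python

    import torch
    from torch import nn

    # Hypothetical sizes: 2 sequences, 5 tokens each, 8-dimensional vectors
    X = torch.randn(2, 5, 8)
    linear = nn.Linear(8, 16)

    # `nn.Linear` acts on the last axis only, so every token is transformed
    # independently: (2, 5, 8) -> (2, 5, 16)
    per_token = linear(X)

    # `nn.Flatten(start_dim=1)` merges all axes after the batch axis:
    # (2, 5, 16) -> (2, 80)
    flattened = nn.Flatten(start_dim=1)(per_token)

    # On a 2-D input there is nothing to merge, so the shape is unchanged
    unchanged = nn.Flatten(start_dim=1)(torch.randn(2, 16))

    print(per_token.shape, flattened.shape, unchanged.shape)

With ``flatten=False`` the token axis is preserved, which is what a
per-token function such as :math:`f` needs.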
It should be highlighted that, in :eq:`eq_nli_e`, :math:`f` takes
inputs :math:`\mathbf{a}_i` and :math:`\mathbf{b}_j` separately rather
than taking a pair of them together as input. This *decomposition* trick
leads to only :math:`m + n` applications (linear complexity) of
:math:`f` rather than :math:`mn` applications (quadratic complexity).
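As a minimal sketch of the decomposition trick, assume a small stand-in
for :math:`f` (a single linear layer with ReLU, without dropout) and
arbitrary sizes. Applying it once per token and taking one matrix
product yields the same scores as looping over all pairs. The loop is
only there as a check; a scoring function that took the pair
:math:`(\mathbf{a}_i, \mathbf{b}_j)` jointly could not avoid the
:math:`mn` evaluations.

.. code:: python

    import torch
    from torch import nn

    torch.manual_seed(0)
    m, n, d, num_hiddens = 4, 3, 5, 6   # hypothetical sizes
    # Stand-in for `f` (an assumption for this sketch, not the `mlp` above)
    f = nn.Sequential(nn.Linear(d, num_hiddens), nn.ReLU())
    A = torch.randn(m, d)   # premise token vectors
    B = torch.randn(n, d)   # hypothesis token vectors

    # Decomposed form: `f` is applied to m + n tokens, then one matrix product
    f_A, f_B = f(A), f(B)
    e_decomposed = f_A @ f_B.T   # shape (m, n)

    # Reference: compute every score e_ij in an explicit loop over all pairs
    e_pairwise = torch.stack([
        torch.stack([f(A[i]) @ f(B[j]) for j in range(n)])
        for i in range(m)])

    print(e_decomposed.shape)
    print(torch.allclose(e_decomposed, e_pairwise, atol=1e-5))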
Normalizing the attention weights in :eq:`eq_nli_e`, we compute the
weighted average of all the token vectors in the hypothesis to obtain
a representation of the hypothesis that is softly aligned with the token
indexed by :math:`i` in the premise:
.. math::
\boldsymbol{\beta}_i = \sum_{j=1}^{n}\frac{\exp(e_{ij})}{ \sum_{k=1}^{n} \exp(e_{ik})} \mathbf{b}_j.
Likewise, we compute soft alignment of premise tokens for each token
indexed by :math:`j` in the hypothesis:
.. math::
\boldsymbol{\alpha}_j = \sum_{i=1}^{m}\frac{\exp(e_{ij})}{ \sum_{k=1}^{m} \exp(e_{kj})} \mathbf{a}_i.
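To make the two normalizations concrete, here is a small sketch for the
running example. The scores in ``e`` are made up by hand and the word
vectors are random stand-ins, so only the shapes and the normalization
axes are meaningful: :math:`\boldsymbol{\beta}_i` normalizes each row of
the score matrix, whereas :math:`\boldsymbol{\alpha}_j` normalizes each
column.

.. code:: python

    import torch
    from torch.nn import functional as F

    premise = ['i', 'do', 'need', 'sleep']   # m = 4
    hypothesis = ['i', 'am', 'tired']        # n = 3
    # Made-up scores e[i, j] between premise token i and hypothesis token j
    e = torch.tensor([[4.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 3.0],
                      [0.0, 0.0, 4.0]])
    torch.manual_seed(0)
    A = torch.randn(4, 5)   # random stand-ins for word vectors, d = 5
    B = torch.randn(3, 5)

    # Weights for beta: normalize over hypothesis tokens (each row sums to 1)
    w_beta = F.softmax(e, dim=1)       # shape (m, n)
    # Weights for alpha: normalize over premise tokens, then transpose
    w_alpha = F.softmax(e, dim=0).T    # shape (n, m), each row sums to 1

    beta = w_beta @ B     # shape (m, d): hypothesis aligned to premise tokens
    alpha = w_alpha @ A   # shape (n, d): premise aligned to hypothesis tokens

    for j, token in enumerate(hypothesis):
        best = premise[int(w_alpha[j].argmax())]
        print(f'{token!r} attends most to {best!r}')

With these scores, “tired” puts most of its weight on “sleep” and most
of the remainder on “need”, which is the soft counterpart of the hard
alignment described above.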
Below we define the ``Attend`` class to compute the soft alignment of
hypotheses (``beta``) with input premises ``A`` and soft alignment of
premises (``alpha``) with input hypotheses ``B``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Attend(nn.Module):
        def __init__(self, num_inputs, num_hiddens, **kwargs):
            super(Attend, self).__init__(**kwargs)
            self.f = mlp(num_inputs, num_hiddens, flatten=False)

        def forward(self, A, B):
            # Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
            # `embed_size`)
            # Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
            # `num_hiddens`)
            f_A = self.f(A)
            f_B = self.f(B)
            # Shape of `e`: (`batch_size`, no. of tokens in sequence A,
            # no. of tokens in sequence B)
            e = torch.bmm(f_A, f_B.permute(0, 2, 1))
            # Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
            # `embed_size`), where sequence B is softly aligned with each token
            # (axis 1 of `beta`) in sequence A
            beta = torch.bmm(F.softmax(e, dim=-1), B)
            # Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
            # `embed_size`), where sequence A is softly aligned with each token
            # (axis 1 of `alpha`) in sequence B
            alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)
            return beta, alpha
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Attend(nn.Block):
        def __init__(self, num_hiddens, **kwargs):
            super(Attend, self).__init__(**kwargs)
            self.f = mlp(num_hiddens=num_hiddens, flatten=False)

        def forward(self, A, B):
            # Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
            # `embed_size`)
            # Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
            # `num_hiddens`)
            f_A = self.f(A)
            f_B = self.f(B)
            # Shape of `e`: (`batch_size`, no. of tokens in sequence A,
            # no. of tokens in sequence B)
            e = npx.batch_dot(f_A, f_B, transpose_b=True)
            # Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
            # `embed_size`), where sequence B is softly aligned with each token
            # (axis 1 of `beta`) in sequence A
            beta = npx.batch_dot(npx.softmax(e), B)
            # Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
            # `embed_size`), where sequence A is softly aligned with each token
            # (axis 1 of `alpha`) in sequence B
            alpha = npx.batch_dot(npx.softmax(e.transpose(0, 2, 1)), A)
            return beta, alpha
Comparing
~~~~~~~~~
In the next step, we compare a token in one sequence with the other
sequence that is softly aligned with that token. Note that in soft
alignment, all the tokens from one sequence, though probably with
different attention weights, will be compared with a token in the other
sequence. For ease of demonstration, :numref:`fig_nli_attention` pairs
tokens with aligned tokens in a *hard* way. For example, suppose that
the attending step determines that “need” and “sleep” in the premise are
both aligned with “tired” in the hypothesis. Then the pair “tired–need sleep”
will be compared.
In the comparing step, we feed the concatenation (operator
:math:`[\cdot, \cdot]`) of tokens from one sequence and aligned tokens
from the other sequence into a function :math:`g` (an MLP):
.. math:: \mathbf{v}_{A,i} = g([\mathbf{a}_i, \boldsymbol{\beta}_i]), i = 1, \ldots, m\\ \mathbf{v}_{B,j} = g([\mathbf{b}_j, \boldsymbol{\alpha}_j]), j = 1, \ldots, n.
:label: eq_nli_v_ab
In :eq:`eq_nli_v_ab`, :math:`\mathbf{v}_{A,i}` is the comparison
between token :math:`i` in the premise and all the hypothesis tokens
that are softly aligned with token :math:`i`, while
:math:`\mathbf{v}_{B,j}` is the comparison between token :math:`j` in
the hypothesis and all the premise tokens that are softly aligned with
token :math:`j`. The following ``Compare`` class defines such a
comparing step.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Compare(nn.Module):
        def __init__(self, num_inputs, num_hiddens, **kwargs):
            super(Compare, self).__init__(**kwargs)
            self.g = mlp(num_inputs, num_hiddens, flatten=False)

        def forward(self, A, B, beta, alpha):
            V_A = self.g(torch.cat([A, beta], dim=2))
            V_B = self.g(torch.cat([B, alpha], dim=2))
            return V_A, V_B
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Compare(nn.Block):
        def __init__(self, num_hiddens, **kwargs):
            super(Compare, self).__init__(**kwargs)
            self.g = mlp(num_hiddens=num_hiddens, flatten=False)

        def forward(self, A, B, beta, alpha):
            V_A = self.g(np.concatenate([A, beta], axis=2))
            V_B = self.g(np.concatenate([B, alpha], axis=2))
            return V_A, V_B
Aggregating
~~~~~~~~~~~
With two sets of comparison vectors :math:`\mathbf{v}_{A,i}`
(:math:`i = 1, \ldots, m`) and :math:`\mathbf{v}_{B,j}`
(:math:`j = 1, \ldots, n`) on hand, in the last step we will aggregate
such information to infer the logical relationship. We begin by summing
up both sets:
.. math::
\mathbf{v}_A = \sum_{i=1}^{m} \mathbf{v}_{A,i}, \quad \mathbf{v}_B = \sum_{j=1}^{n}\mathbf{v}_{B,j}.
Next we feed the concatenation of both summarization results into
function :math:`h` (an MLP) to obtain the classification result of the
logical relationship:
.. math::
\hat{\mathbf{y}} = h([\mathbf{v}_A, \mathbf{v}_B]).
The aggregation step is defined in the following ``Aggregate`` class.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Aggregate(nn.Module):
        def __init__(self, num_inputs, num_hiddens, num_outputs, **kwargs):
            super(Aggregate, self).__init__(**kwargs)
            self.h = mlp(num_inputs, num_hiddens, flatten=True)
            self.linear = nn.Linear(num_hiddens, num_outputs)

        def forward(self, V_A, V_B):
            # Sum up both sets of comparison vectors
            V_A = V_A.sum(dim=1)
            V_B = V_B.sum(dim=1)
            # Feed the concatenation of both summarization results into an MLP
            Y_hat = self.linear(self.h(torch.cat([V_A, V_B], dim=1)))
            return Y_hat
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class Aggregate(nn.Block):
        def __init__(self, num_hiddens, num_outputs, **kwargs):
            super(Aggregate, self).__init__(**kwargs)
            self.h = mlp(num_hiddens=num_hiddens, flatten=True)
            self.h.add(nn.Dense(num_outputs))

        def forward(self, V_A, V_B):
            # Sum up both sets of comparison vectors
            V_A = V_A.sum(axis=1)
            V_B = V_B.sum(axis=1)
            # Feed the concatenation of both summarization results into an MLP
            Y_hat = self.h(np.concatenate([V_A, V_B], axis=1))
            return Y_hat
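The sketch below uses random stand-in tensors rather than real
comparison vectors; the sizes are arbitrary. It shows how summing over
the token axis removes the dependence on sequence length, so that the
input of :math:`h` always has width ``2 * num_hiddens``, and it checks
that the sum itself does not depend on the order of the tokens.

.. code:: python

    import torch

    torch.manual_seed(0)
    batch_size, m, n, num_hiddens = 2, 4, 3, 6   # hypothetical sizes
    V_A = torch.randn(batch_size, m, num_hiddens)
    V_B = torch.randn(batch_size, n, num_hiddens)

    # Summing over the token axis gives one vector per sequence
    v_A = V_A.sum(dim=1)   # shape (batch_size, num_hiddens)
    v_B = V_B.sum(dim=1)   # shape (batch_size, num_hiddens)
    h_input = torch.cat([v_A, v_B], dim=1)   # (batch_size, 2 * num_hiddens)

    # Reordering the tokens of a sequence leaves its sum unchanged
    perm = torch.tensor([2, 0, 3, 1])
    v_A_shuffled = V_A[:, perm].sum(dim=1)

    print(h_input.shape)
    print(torch.allclose(v_A, v_A_shuffled, atol=1e-5))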
Putting It All Together
~~~~~~~~~~~~~~~~~~~~~~~
By putting the attending, comparing, and aggregating steps together, we
define the decomposable attention model to jointly train these three
steps.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class DecomposableAttention(nn.Module):
        def __init__(self, vocab, embed_size, num_hiddens, num_inputs_attend=100,
                     num_inputs_compare=200, num_inputs_agg=400, **kwargs):
            super(DecomposableAttention, self).__init__(**kwargs)
            self.embedding = nn.Embedding(len(vocab), embed_size)
            self.attend = Attend(num_inputs_attend, num_hiddens)
            self.compare = Compare(num_inputs_compare, num_hiddens)
            # There are 3 possible outputs: entailment, contradiction, and neutral
            self.aggregate = Aggregate(num_inputs_agg, num_hiddens, num_outputs=3)

        def forward(self, X):
            premises, hypotheses = X
            A = self.embedding(premises)
            B = self.embedding(hypotheses)
            beta, alpha = self.attend(A, B)
            V_A, V_B = self.compare(A, B, beta, alpha)
            Y_hat = self.aggregate(V_A, V_B)
            return Y_hat
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class DecomposableAttention(nn.Block):
        def __init__(self, vocab, embed_size, num_hiddens, **kwargs):
            super(DecomposableAttention, self).__init__(**kwargs)
            self.embedding = nn.Embedding(len(vocab), embed_size)
            self.attend = Attend(num_hiddens)
            self.compare = Compare(num_hiddens)
            # There are 3 possible outputs: entailment, contradiction, and neutral
            self.aggregate = Aggregate(num_hiddens, 3)

        def forward(self, X):
            premises, hypotheses = X
            A = self.embedding(premises)
            B = self.embedding(hypotheses)
            beta, alpha = self.attend(A, B)
            V_A, V_B = self.compare(A, B, beta, alpha)
            Y_hat = self.aggregate(V_A, V_B)
            return Y_hat
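To trace the tensor shapes through all three steps, the following
self-contained sketch repeats the forward pass with tiny, arbitrary
sizes and a simplified stand-in for ``mlp`` (``small_mlp``, a
hypothetical helper without dropout). It also shows where the input
widths come from: the attending MLP sees one token (``embed_size``),
the comparing MLP sees a token concatenated with its aligned vector
(``2 * embed_size``), and the aggregating MLP sees two summed comparison
vectors (``2 * num_hiddens``). With ``embed_size = 100`` and
``num_hiddens = 200`` these are the defaults 100, 200, and 400 of the
PyTorch class above.

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F

    def small_mlp(num_inputs, num_hiddens):
        # Hypothetical, simplified stand-in for `mlp` (no dropout)
        return nn.Sequential(nn.Linear(num_inputs, num_hiddens), nn.ReLU(),
                             nn.Linear(num_hiddens, num_hiddens), nn.ReLU())

    torch.manual_seed(0)
    batch_size, m, n = 2, 5, 3                        # arbitrary sizes
    vocab_size, embed_size, num_hiddens = 20, 8, 16   # arbitrary sizes

    embedding = nn.Embedding(vocab_size, embed_size)
    f = small_mlp(embed_size, num_hiddens)        # attending
    g = small_mlp(2 * embed_size, num_hiddens)    # comparing
    h = nn.Sequential(small_mlp(2 * num_hiddens, num_hiddens),
                      nn.Linear(num_hiddens, 3))  # aggregating, 3 classes

    premises = torch.randint(0, vocab_size, (batch_size, m))
    hypotheses = torch.randint(0, vocab_size, (batch_size, n))
    A, B = embedding(premises), embedding(hypotheses)

    # Attending: scores, then soft alignment in both directions
    e = torch.bmm(f(A), f(B).permute(0, 2, 1))                   # (2, 5, 3)
    beta = torch.bmm(F.softmax(e, dim=-1), B)                    # (2, 5, 8)
    alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)  # (2, 3, 8)
    # Comparing: each token together with its aligned representation
    V_A = g(torch.cat([A, beta], dim=2))                         # (2, 5, 16)
    V_B = g(torch.cat([B, alpha], dim=2))                        # (2, 3, 16)
    # Aggregating: sum over tokens, concatenate, classify
    Y_hat = h(torch.cat([V_A.sum(dim=1), V_B.sum(dim=1)], dim=1))  # (2, 3)

    print(e.shape, beta.shape, alpha.shape, V_A.shape, V_B.shape, Y_hat.shape)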
Training and Evaluating the Model
---------------------------------
Now we will train and evaluate the defined decomposable attention model
on the SNLI dataset. We begin by reading the dataset.
Reading the Dataset
~~~~~~~~~~~~~~~~~~~
We download and read the SNLI dataset using the function defined in
:numref:`sec_natural-language-inference-and-dataset`. The batch size
and sequence length are set to :math:`256` and :math:`50`, respectively.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...
read 549367 examples
read 9824 examples
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...
[21:49:40] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
read 549367 examples
read 9824 examples
Creating the Model
~~~~~~~~~~~~~~~~~~
We use the pretrained 100-dimensional GloVe embedding to represent the
input tokens. Thus, we predefine the dimension of vectors
:math:`\mathbf{a}_i` and :math:`\mathbf{b}_j` in :eq:`eq_nli_e` as
100. The output dimension of functions :math:`f` in :eq:`eq_nli_e`
and :math:`g` in :eq:`eq_nli_v_ab` is set to 200. Then we create a
model instance, initialize its parameters, and load the GloVe embedding
to initialize vectors of input tokens.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds);
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
net.initialize(init.Xavier(), ctx=devices)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.set_data(embeds)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[21:49:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[21:49:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...
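As a toy illustration of what copying pretrained vectors into an
embedding layer does in PyTorch, the sketch below uses a made-up
:math:`4 \times 3` matrix in place of the GloVe vectors. After the copy,
looking up a token index returns the corresponding row of that matrix.
The copy does not change ``requires_grad``, so in this sketch the
vectors remain trainable parameters.

.. code:: python

    import torch
    from torch import nn

    # Hypothetical vocabulary of 4 tokens with 3-dimensional vectors
    embedding = nn.Embedding(4, 3)
    # Made-up stand-in for pretrained vectors: row i holds 3i, 3i+1, 3i+2
    pretrained = torch.arange(12, dtype=torch.float32).reshape(4, 3)
    embedding.weight.data.copy_(pretrained)

    tokens = torch.tensor([[0, 2]])   # one sequence with token indices 0 and 2
    vectors = embedding(tokens)       # shape (1, 2, 3): rows 0 and 2

    print(vectors)
    print(embedding.weight.requires_grad)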
Training and Evaluating the Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In contrast to the ``split_batch`` function in :numref:`sec_multi_gpu`
that takes single inputs such as text sequences (or images), we define a
``split_batch_multi_inputs`` function to take multiple inputs such as
premises and hypotheses in minibatches.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def split_batch_multi_inputs(X, y, devices):
        """Split multi-input `X` and `y` into multiple devices."""
        X = list(zip(*[gluon.utils.split_and_load(
            feature, devices, even_split=False) for feature in X]))
        return (X, gluon.utils.split_and_load(y, devices, even_split=False))
Now we can train and evaluate the model on the SNLI dataset.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
lr, num_epochs = 0.001, 4
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
loss 0.496, train acc 0.805, test acc 0.828
20383.2 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
.. figure:: output_natural-language-inference-attention_b907c4_81_1.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
lr, num_epochs = 0.001, 4
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
split_batch_multi_inputs)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
loss 0.514, train acc 0.797, test acc 0.814
4621.6 examples/sec on [gpu(0), gpu(1)]
.. figure:: output_natural-language-inference-attention_b907c4_84_1.svg
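The PyTorch cell above constructs the loss with ``reduction="none"``.
The sketch below, with made-up logits and labels, shows what this
argument changes: the loss returns one value per example instead of a
single averaged scalar, which leaves the choice of reduction to the
training loop. How ``d2l.train_ch13`` reduces these values is not shown
in this section.

.. code:: python

    import torch
    from torch import nn

    # Made-up logits for 2 examples and 3 classes, with made-up labels
    logits = torch.tensor([[2.0, 0.5, 0.1],
                           [0.1, 0.2, 3.0]])
    labels = torch.tensor([0, 1])

    # One loss value per example
    per_example = nn.CrossEntropyLoss(reduction="none")(logits, labels)
    # Default reduction: the mean over the examples
    averaged = nn.CrossEntropyLoss()(logits, labels)

    print(per_example.shape)
    print(torch.allclose(per_example.mean(), averaged))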
Using the Model
~~~~~~~~~~~~~~~
Finally, we define the prediction function to output the logical
relationship between a premise and a hypothesis.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def predict_snli(net, vocab, premise, hypothesis):
        """Predict the logical relationship between the premise and hypothesis."""
        net.eval()
        premise = torch.tensor(vocab[premise], device=d2l.try_gpu())
        hypothesis = torch.tensor(vocab[hypothesis], device=d2l.try_gpu())
        label = torch.argmax(net([premise.reshape((1, -1)),
                                  hypothesis.reshape((1, -1))]), dim=1)
        return 'entailment' if label == 0 else 'contradiction' if label == 1 \
                else 'neutral'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def predict_snli(net, vocab, premise, hypothesis):
        """Predict the logical relationship between the premise and hypothesis."""
        premise = np.array(vocab[premise], ctx=d2l.try_gpu())
        hypothesis = np.array(vocab[hypothesis], ctx=d2l.try_gpu())
        label = np.argmax(net([premise.reshape((1, -1)),
                               hypothesis.reshape((1, -1))]), axis=1)
        return 'entailment' if label == 0 else 'contradiction' if label == 1 \
                else 'neutral'
We can use the trained model to obtain the natural language inference
result for a sample pair of sentences.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.'])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'contradiction'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.'])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'contradiction'
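The prediction function returns only the most likely label. As a
possible variation, the sketch below applies a softmax to made-up logits
(standing in for the model output) to obtain a probability for each of
the three labels, using the same index order as ``predict_snli``. The
``LABELS`` list is a name introduced here for illustration.

.. code:: python

    import torch
    from torch.nn import functional as F

    # Same index order as in `predict_snli`
    LABELS = ['entailment', 'contradiction', 'neutral']
    # Made-up logits standing in for the output of `net` on one pair
    logits = torch.tensor([[0.2, 2.5, -1.0]])

    probs = F.softmax(logits, dim=1)            # shape (1, 3), sums to 1
    label = LABELS[int(probs.argmax(dim=1))]

    for name, p in zip(LABELS, probs[0].tolist()):
        print(f'{name}: {p:.3f}')
    print(label)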
Summary
-------
- The decomposable attention model consists of three steps for
predicting the logical relationships between premises and hypotheses:
attending, comparing, and aggregating.
- With attention mechanisms, we can align tokens in one text sequence
to every token in the other, and vice versa. Such alignment is soft:
it uses a weighted average, where ideally large weights are associated
with the tokens to be aligned.
- The decomposition trick leads to linear rather than quadratic
complexity in the number of applications of :math:`f` when computing
attention weights.
- We can use pretrained word vectors as the input representation for
downstream natural language processing tasks such as natural language
inference.
Exercises
---------
1. Train the model with other combinations of hyperparameters. Can you
get better accuracy on the test set?
2. What are the major drawbacks of the decomposable attention model for
natural language inference?
3. Suppose that we want to get the level of semantic similarity (e.g.,
a continuous value between 0 and 1) for any pair of sentences. How
shall we collect and label the dataset? Can you design a model with
attention mechanisms?
`Discussions `__