.. _sec_natural-language-inference-bert:

Natural Language Inference: Fine-Tuning BERT
============================================

In earlier sections of this chapter, we have designed an attention-based
architecture (in :numref:`sec_natural-language-inference-attention`) for the
natural language inference task on the SNLI dataset (as described in
:numref:`sec_natural-language-inference-and-dataset`). Now we revisit this task
by fine-tuning BERT. As discussed in :numref:`sec_finetuning-bert`, natural
language inference is a sequence-level text pair classification problem, and
fine-tuning BERT only requires an additional MLP-based architecture, as
illustrated in :numref:`fig_nlp-map-nli-bert`.

.. _fig_nlp-map-nli-bert:

.. figure:: ../img/nlp-map-nli-bert.svg

   This section feeds pretrained BERT to an MLP-based architecture for natural
   language inference.

In this section, we will download a pretrained small version of BERT, then
fine-tune it for natural language inference on the SNLI dataset.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import json
    import multiprocessing
    import os
    import torch
    from torch import nn
    from d2l import torch as d2l

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import json
    import multiprocessing
    import os
    from mxnet import gluon, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

.. raw:: html
.. raw:: html
Loading Pretrained BERT
-----------------------

We have explained how to pretrain BERT on the WikiText-2 dataset in
:numref:`sec_bert-dataset` and :numref:`sec_bert-pretraining` (note that the
original BERT model is pretrained on much bigger corpora). As discussed in
:numref:`sec_bert-pretraining`, the original BERT model has hundreds of
millions of parameters. In the following, we provide two versions of pretrained
BERT: “bert.base” is about as big as the original BERT base model, which
requires a lot of computational resources to fine-tune, while “bert.small” is a
small version to facilitate demonstration.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                                 '225d66f04cae318b841a13d32af3acc165f253ac')
    d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                                  'c72329e68a732bef0452e4b96a1c341c8910f81f')

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
                                 '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
    d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
                                  'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

.. raw:: html
.. raw:: html
Either pretrained BERT model contains a “vocab.json” file that defines the
vocabulary set and a “pretrained.params” file of the pretrained parameters. We
implement the following ``load_pretrained_model`` function to load pretrained
BERT parameters.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                              num_heads, num_blks, dropout, max_len, devices):
        data_dir = d2l.download_extract(pretrained_model)
        # Define an empty vocabulary to load the predefined vocabulary
        vocab = d2l.Vocab()
        vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
        vocab.token_to_idx = {token: idx for idx, token in enumerate(
            vocab.idx_to_token)}
        bert = d2l.BERTModel(
            len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens,
            num_heads=num_heads, num_blks=num_blks, dropout=dropout,
            max_len=max_len)
        # Load pretrained BERT parameters
        bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                     'pretrained.params')))
        return bert, vocab

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                              num_heads, num_blks, dropout, max_len, devices):
        data_dir = d2l.download_extract(pretrained_model)
        # Define an empty vocabulary to load the predefined vocabulary
        vocab = d2l.Vocab()
        vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
        vocab.token_to_idx = {token: idx for idx, token in enumerate(
            vocab.idx_to_token)}
        bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
                             num_blks, dropout, max_len)
        # Load pretrained BERT parameters
        bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
                             ctx=devices)
        return bert, vocab

.. raw:: html
.. raw:: html
To facilitate demonstration on most machines, we will load and fine-tune the
small version (“bert.small”) of the pretrained BERT in this section. In the
exercises, we will see how fine-tuning the much larger “bert.base” can
significantly improve the testing accuracy.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    devices = d2l.try_all_gpus()
    bert, vocab = load_pretrained_model(
        'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
        num_blks=2, dropout=0.1, max_len=512, devices=devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    devices = d2l.try_all_gpus()
    bert, vocab = load_pretrained_model(
        'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
        num_blks=2, dropout=0.1, max_len=512, devices=devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...
    [21:49:07] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [21:49:08] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
    [21:49:08] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU

.. raw:: html
.. raw:: html
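Before constructing the fine-tuning dataset, we can sanity-check what the
loaded model provides. The following sketch applies to the PyTorch tab only;
the toy sentences and variable names are ours and not part of the pipeline
above. It packs a toy premise and hypothesis with
``d2l.get_tokens_and_segments`` and inspects the encoder output, whose first
position corresponds to the special “&lt;cls&gt;” token.

.. code:: python

    # Illustrative sketch (PyTorch tab): pack a toy premise and hypothesis and
    # inspect the BERT encoder output. The sentences below are made up.
    premise_tokens = ['a', 'man', 'is', 'eating', 'an', 'apple', '.']
    hypothesis_tokens = ['a', 'person', 'is', 'eating', '.']
    # '<cls>' premise '<sep>' hypothesis '<sep>', with segment IDs 0 and 1
    tokens, segments = d2l.get_tokens_and_segments(premise_tokens,
                                                   hypothesis_tokens)
    token_ids = torch.tensor(vocab[tokens]).unsqueeze(0)  # Shape: (1, seq_len)
    segment_ids = torch.tensor(segments).unsqueeze(0)     # Shape: (1, seq_len)
    valid_len = torch.tensor([len(tokens)])
    with torch.no_grad():
        encoded = bert.encoder(token_ids, segment_ids, valid_len)
    # Encoder output: (1, seq_len, 256); the '<cls>' representation: (1, 256)
    encoded.shape, encoded[:, 0, :].shape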
The Dataset for Fine-Tuning BERT
--------------------------------

For the downstream task of natural language inference on the SNLI dataset, we
define a customized dataset class ``SNLIBERTDataset``. In each example, the
premise and hypothesis form a pair of text sequences and are packed into one
BERT input sequence as depicted in :numref:`fig_bert-two-seqs`. Recall from
:numref:`subsec_bert_input_rep` that segment IDs are used to distinguish the
premise and the hypothesis in a BERT input sequence. Given the predefined
maximum length of a BERT input sequence (``max_len``), tokens are repeatedly
removed from the end of the longer sequence in the input text pair until
``max_len`` is met. To accelerate generation of the SNLI dataset for
fine-tuning BERT, we use 4 worker processes to generate training or testing
examples in parallel.

.. raw:: html
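To see this truncation rule in isolation, here is a minimal standalone sketch
(the helper name and toy token lists are ours): tokens are popped from the end
of whichever sequence is currently longer until the pair, plus the three
reserved special tokens, fits within the maximum length. The
``SNLIBERTDataset`` class below implements the same logic as a method.

.. code:: python

    # Standalone sketch of the truncation rule used below (illustrative only)
    def truncate_pair_of_tokens(p_tokens, h_tokens, max_len):
        # Reserve 3 slots for the '<cls>', '<sep>', and '<sep>' tokens
        while len(p_tokens) + len(h_tokens) > max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    p = ['a', 'man', 'is', 'eating', 'a', 'big', 'red', 'apple', 'outside']
    h = ['a', 'person', 'is', 'eating']
    truncate_pair_of_tokens(p, h, max_len=12)
    # The longer premise is trimmed from 9 to 5 tokens; the hypothesis is kept
    p, h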
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SNLIBERTDataset(torch.utils.data.Dataset):
        def __init__(self, dataset, max_len, vocab=None):
            all_premise_hypothesis_tokens = [[
                p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]

            self.labels = torch.tensor(dataset[2])
            self.vocab = vocab
            self.max_len = max_len
            (self.all_token_ids, self.all_segments,
             self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
            print('read ' + str(len(self.all_token_ids)) + ' examples')

        def _preprocess(self, all_premise_hypothesis_tokens):
            pool = multiprocessing.Pool(4)  # Use 4 worker processes
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
            all_token_ids = [
                token_ids for token_ids, segments, valid_len in out]
            all_segments = [segments for token_ids, segments, valid_len in out]
            valid_lens = [valid_len for token_ids, segments, valid_len in out]
            return (torch.tensor(all_token_ids, dtype=torch.long),
                    torch.tensor(all_segments, dtype=torch.long),
                    torch.tensor(valid_lens))

        def _mp_worker(self, premise_hypothesis_tokens):
            p_tokens, h_tokens = premise_hypothesis_tokens
            self._truncate_pair_of_tokens(p_tokens, h_tokens)
            tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
            token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                                 * (self.max_len - len(tokens))
            segments = segments + [0] * (self.max_len - len(segments))
            valid_len = len(tokens)
            return token_ids, segments, valid_len

        def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
            # Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the
            # BERT input
            while len(p_tokens) + len(h_tokens) > self.max_len - 3:
                if len(p_tokens) > len(h_tokens):
                    p_tokens.pop()
                else:
                    h_tokens.pop()

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx]), self.labels[idx]

        def __len__(self):
            return len(self.all_token_ids)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SNLIBERTDataset(gluon.data.Dataset):
        def __init__(self, dataset, max_len, vocab=None):
            all_premise_hypothesis_tokens = [[
                p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]

            self.labels = np.array(dataset[2])
            self.vocab = vocab
            self.max_len = max_len
            (self.all_token_ids, self.all_segments,
             self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
            print('read ' + str(len(self.all_token_ids)) + ' examples')

        def _preprocess(self, all_premise_hypothesis_tokens):
            pool = multiprocessing.Pool(4)  # Use 4 worker processes
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
            all_token_ids = [
                token_ids for token_ids, segments, valid_len in out]
            all_segments = [segments for token_ids, segments, valid_len in out]
            valid_lens = [valid_len for token_ids, segments, valid_len in out]
            return (np.array(all_token_ids, dtype='int32'),
                    np.array(all_segments, dtype='int32'),
                    np.array(valid_lens))

        def _mp_worker(self, premise_hypothesis_tokens):
            p_tokens, h_tokens = premise_hypothesis_tokens
            self._truncate_pair_of_tokens(p_tokens, h_tokens)
            tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
            token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                                 * (self.max_len - len(tokens))
            segments = segments + [0] * (self.max_len - len(segments))
            valid_len = len(tokens)
            return token_ids, segments, valid_len

        def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
            # Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the
            # BERT input
            while len(p_tokens) + len(h_tokens) > self.max_len - 3:
                if len(p_tokens) > len(h_tokens):
                    p_tokens.pop()
                else:
                    h_tokens.pop()

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx]), self.labels[idx]

        def __len__(self):
            return len(self.all_token_ids)

.. raw:: html
.. raw:: html
After downloading the SNLI dataset, we generate training and testing examples
by instantiating the ``SNLIBERTDataset`` class. Such examples will be read in
minibatches during training and testing of natural language inference.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Reduce `batch_size` if there is an out of memory error. In the original
    # BERT model, `max_len` = 512
    batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
    test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            num_workers=num_workers)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    read 549367 examples
    read 9824 examples

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Reduce `batch_size` if there is an out of memory error. In the original
    # BERT model, `max_len` = 512
    batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
    test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
    train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                                       num_workers=num_workers)
    test_iter = gluon.data.DataLoader(test_set, batch_size,
                                      num_workers=num_workers)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    read 549367 examples
    read 9824 examples

.. raw:: html
.. raw:: html
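Before defining the classifier, it can help to peek at one minibatch to
confirm the shapes produced by this data pipeline. The following is an
illustrative sketch for the PyTorch tab; with ``batch_size=512`` and
``max_len=128`` we expect token IDs and segment IDs of shape (512, 128), and
valid lengths and labels of shape (512,).

.. code:: python

    # Illustrative sketch (PyTorch tab): inspect the shapes of one minibatch
    (tokens_X, segments_X, valid_lens_x), y = next(iter(train_iter))
    tokens_X.shape, segments_X.shape, valid_lens_x.shape, y.shape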
Fine-Tuning BERT
----------------

As :numref:`fig_nlp-map-nli-bert` indicates, fine-tuning BERT for natural
language inference requires only an extra MLP consisting of two fully connected
layers (see ``self.hidden`` and ``self.output`` in the following
``BERTClassifier`` class). This MLP transforms the BERT representation of the
special “&lt;cls&gt;” token, which encodes the information of both the premise
and the hypothesis, into three outputs of natural language inference:
entailment, contradiction, and neutral.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class BERTClassifier(nn.Module):
        def __init__(self, bert):
            super(BERTClassifier, self).__init__()
            self.encoder = bert.encoder
            self.hidden = bert.hidden
            self.output = nn.LazyLinear(3)

        def forward(self, inputs):
            tokens_X, segments_X, valid_lens_x = inputs
            encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
            return self.output(self.hidden(encoded_X[:, 0, :]))

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class BERTClassifier(nn.Block):
        def __init__(self, bert):
            super(BERTClassifier, self).__init__()
            self.encoder = bert.encoder
            self.hidden = bert.hidden
            self.output = nn.Dense(3)

        def forward(self, inputs):
            tokens_X, segments_X, valid_lens_x = inputs
            encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
            return self.output(self.hidden(encoded_X[:, 0, :]))

.. raw:: html
.. raw:: html
In the following, the pretrained BERT model ``bert`` is fed into the
``BERTClassifier`` instance ``net`` for the downstream application. In common
implementations of BERT fine-tuning, only the parameters of the output layer of
the additional MLP (``net.output``) will be learned from scratch. All the
parameters of the pretrained BERT encoder (``net.encoder``) and the hidden
layer of the additional MLP (``net.hidden``) will be fine-tuned.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = BERTClassifier(bert)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = BERTClassifier(bert)
    net.output.initialize(ctx=devices)

.. raw:: html
.. raw:: html
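As a quick check of the statement above, the classifier reuses the pretrained
encoder and hidden (pooling) layer by reference in either implementation, so
fine-tuning starts from the pretrained weights, whereas ``net.output`` is a
freshly initialized layer. The following is an illustrative sketch rather than
part of the training pipeline.

.. code:: python

    # Illustrative sketch: `net` shares the pretrained encoder and hidden
    # (pooling) layer with `bert` by reference; only `net.output` is new
    print(net.encoder is bert.encoder, net.hidden is bert.hidden)  # True True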
Recall that in :numref:`sec_bert` both the ``MaskLM`` class and the
``NextSentencePred`` class have parameters in their employed MLPs. These
parameters are part of those in the pretrained BERT model ``bert``, and thus
part of the parameters in ``net``. However, such parameters are only for
computing the masked language modeling loss and the next sentence prediction
loss during pretraining. These two loss functions are irrelevant to fine-tuning
downstream applications; thus, the parameters of the employed MLPs in
``MaskLM`` and ``NextSentencePred`` are not updated (stale) when BERT is
fine-tuned.

To allow parameters with stale gradients, the flag ``ignore_stale_grad=True``
is set in the ``step`` function of ``d2l.train_batch_ch13``. We use this
function to train and evaluate the model ``net`` using the training set
(``train_iter``) and the testing set (``test_iter``) of SNLI. Due to limited
computational resources, the training and testing accuracy can be further
improved: we leave its discussion to the exercises.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 1e-4, 5
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss(reduction='none')
    net(next(iter(train_iter))[0])
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.520, train acc 0.791, test acc 0.786
    10588.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

.. figure:: output_natural-language-inference-bert_1857e6_75_1.svg

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 1e-4, 5
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices, d2l.split_batch_multi_inputs)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.477, train acc 0.811, test acc 0.789
    4652.5 examples/sec on [gpu(0), gpu(1)]

.. figure:: output_natural-language-inference-bert_1857e6_78_1.svg

.. raw:: html
.. raw:: html
Summary
-------

- We can fine-tune the pretrained BERT model for downstream applications, such
  as natural language inference on the SNLI dataset.
- During fine-tuning, the BERT model becomes part of the model for the
  downstream application. Parameters that are only related to pretraining loss
  will not be updated during fine-tuning.

Exercises
---------

1. Fine-tune a much larger pretrained BERT model that is about as big as the
   original BERT base model if your computational resources allow. Set
   arguments in the ``load_pretrained_model`` function as follows: replace
   ‘bert.small’ with ‘bert.base’, and increase the values of
   ``num_hiddens=256``, ``ffn_num_hiddens=512``, ``num_heads=4``, and
   ``num_blks=2`` to 768, 3072, 12, and 12, respectively (see the sketch after
   this list for the corresponding call). By increasing the number of
   fine-tuning epochs (and possibly tuning other hyperparameters), can you get
   a testing accuracy higher than 0.86?
2. How can we truncate a pair of sequences according to their length ratio?
   Compare this pair truncation method with the one used in the
   ``SNLIBERTDataset`` class. What are their pros and cons?

.. raw:: html
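For the first exercise, the configuration change amounts to a call like the
following sketch (not executed here; fine-tuning “bert.base” requires far more
GPU memory and time, and the variable names ``bert_base`` and ``vocab_base``
are ours).

.. code:: python

    # Sketch for Exercise 1 (not executed here): hyperparameters matching the
    # original BERT base model
    bert_base, vocab_base = load_pretrained_model(
        'bert.base', num_hiddens=768, ffn_num_hiddens=3072, num_heads=12,
        num_blks=12, dropout=0.1, max_len=512, devices=devices)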
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html