.. _sec_nin:

Network in Network (NiN)
========================

LeNet, AlexNet, and VGG all share a common design pattern: extract
features exploiting *spatial* structure via a sequence of convolutions
and pooling layers, and post-process the representations via fully
connected layers. The improvements upon LeNet by AlexNet and VGG mainly
lie in how these later networks widen and deepen these two modules.

This design poses two major challenges. First, the fully connected
layers at the end of the architecture consume tremendous numbers of
parameters. For instance, even a simple model such as VGG-11 requires a
monstrous weight matrix in its first fully connected layer, occupying
almost 400MB of RAM in single precision (FP32). This is a significant
impediment to computation, in particular on mobile and embedded devices.
After all, even high-end mobile phones sport no more than 8GB of RAM. At
the time VGG was invented, this was an order of magnitude less (the
iPhone 4S had 512MB). As such, it would have been difficult to justify
spending the majority of memory on an image classifier. Second, it is
equally impractical to add fully connected layers earlier in the network
to increase the degree of nonlinearity: doing so would destroy the
spatial structure and potentially require even more memory.

The *network in network* (*NiN*) blocks :cite:`Lin.Chen.Yan.2013` offer
an alternative, capable of solving both problems in one simple strategy.
They rest on a simple insight: (i) use :math:`1 \times 1` convolutions
to add local nonlinearities across the channel activations and (ii) use
global average pooling to integrate across all locations in the last
representation layer. Note that global average pooling would not be
effective, were it not for the added nonlinearities. Let's dive into
this in detail.
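As a sanity check on the memory figure, the arithmetic fits in a few
lines of plain Python. This sketch assumes the standard VGG-11
configuration: for :math:`224 \times 224` inputs, the last convolutional
stage produces a :math:`512 \times 7 \times 7` feature map, which is
flattened and fed to a 4096-unit fully connected layer whose weights and
biases are stored as 4-byte FP32 values.

```python
# Back-of-the-envelope size of VGG-11's first fully connected layer.
# Assumption: standard VGG-11 layout, where the final 512-channel 7x7
# feature map is flattened and fed to a 4096-unit fully connected layer.
in_features = 512 * 7 * 7      # 25,088 flattened inputs
out_features = 4096
bytes_per_fp32 = 4

num_params = in_features * out_features + out_features  # weights + biases
num_bytes = num_params * bytes_per_fp32

print(f"parameters: {num_params:,}")         # parameters: 102,764,544
print(f"size: {num_bytes / 2**20:.0f} MiB")  # size: 392 MiB
```

The result, roughly 392 MiB (about 411 MB in decimal units) for this one
layer alone, is consistent with the "almost 400MB" quoted above; the two
subsequent fully connected layers of VGG-11 add more on top.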
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # PyTorch
    import torch
    from torch import nn
    from d2l import torch as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # MXNet
    from mxnet import init, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # JAX
    import jax
    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    import tensorflow as tf
    from d2l import tensorflow as d2l
NiN Blocks
----------

Recall :numref:`subsec_1x1`, where we noted that the inputs and outputs
of convolutional layers consist of four-dimensional tensors with axes
corresponding to the example, channel, height, and width. Also recall
that the inputs and outputs of fully connected layers are typically
two-dimensional tensors corresponding to the example and feature. The
idea behind NiN is to apply a fully connected layer at each pixel
location (for each height and width). The resulting :math:`1 \times 1`
convolution can be thought of as a fully connected layer that acts
independently on each pixel location, sharing its weights across all
locations.

:numref:`fig_nin` illustrates the main structural differences between
VGG and NiN, and between their blocks. Note the differences both within
the blocks (in NiN the initial convolution is followed by
:math:`1 \times 1` convolutions, whereas VGG retains :math:`3 \times 3`
convolutions) and at the end of the network, where NiN no longer
requires a giant fully connected layer.

.. _fig_nin:

.. figure:: ../img/nin.svg
   :width: 600px

   Comparing the architectures of VGG and NiN, and of their blocks.
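To make the per-pixel interpretation concrete, the NumPy sketch below
computes a :math:`1 \times 1` convolution in two ways on random data
(the shapes and inputs are arbitrary illustrative choices, and no deep
learning framework is involved): once by summing over input channels at
every location, and once by treating each pixel as an example with
:math:`c_\text{in}` features and applying one ordinary fully connected
layer to all pixels at once.

```python
import numpy as np

rng = np.random.default_rng(0)
c_in, c_out, h, w = 3, 4, 5, 5
X = rng.normal(size=(c_in, h, w))    # one example, channels-first
W = rng.normal(size=(c_out, c_in))   # 1x1 kernel, squeezed to (c_out, c_in)
b = rng.normal(size=(c_out,))

# View 1: a 1x1 convolution sums over input channels at every (i, j).
Y_conv = np.einsum('oc,chw->ohw', W, X) + b[:, None, None]

# View 2: treat each pixel as an example with c_in features and apply an
# ordinary fully connected layer (one matrix multiply) to all pixels.
pixels = X.reshape(c_in, h * w).T    # shape (h*w, c_in)
Y_fc = (pixels @ W.T + b).T.reshape(c_out, h, w)

print(np.allclose(Y_conv, Y_fc))  # True
```

Since the same weight matrix is applied at every location, the number of
parameters (:math:`c_\text{out} \times c_\text{in}` weights plus
:math:`c_\text{out}` biases) does not depend on the image size.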
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # PyTorch
    def nin_block(out_channels, kernel_size, strides, padding):
        return nn.Sequential(
            nn.LazyConv2d(out_channels, kernel_size, strides, padding),
            nn.ReLU(),
            # Two 1x1 convolutions: per-pixel fully connected layers
            nn.LazyConv2d(out_channels, kernel_size=1), nn.ReLU(),
            nn.LazyConv2d(out_channels, kernel_size=1), nn.ReLU())

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # MXNet
    def nin_block(num_channels, kernel_size, strides, padding):
        blk = nn.Sequential()
        blk.add(nn.Conv2D(num_channels, kernel_size, strides, padding,
                          activation='relu'),
                # Two 1x1 convolutions: per-pixel fully connected layers
                nn.Conv2D(num_channels, kernel_size=1, activation='relu'),
                nn.Conv2D(num_channels, kernel_size=1, activation='relu'))
        return blk

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # JAX
    def nin_block(out_channels, kernel_size, strides, padding):
        return nn.Sequential([
            nn.Conv(out_channels, kernel_size, strides, padding),
            nn.relu,
            # Two 1x1 convolutions: per-pixel fully connected layers
            nn.Conv(out_channels, kernel_size=(1, 1)), nn.relu,
            nn.Conv(out_channels, kernel_size=(1, 1)), nn.relu])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    def nin_block(out_channels, kernel_size, strides, padding):
        return tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(out_channels, kernel_size,
                                   strides=strides, padding=padding),
            tf.keras.layers.Activation('relu'),
            # Two 1x1 convolutions: per-pixel fully connected layers
            tf.keras.layers.Conv2D(out_channels, 1),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Conv2D(out_channels, 1),
            tf.keras.layers.Activation('relu')])
NiN Model
---------

NiN uses the same initial convolution sizes as AlexNet (it was proposed
shortly thereafter). The kernel sizes are :math:`11\times 11`,
:math:`5\times 5`, and :math:`3\times 3`, respectively, and the numbers
of output channels match those of AlexNet. Each of the first three NiN
blocks is followed by a max-pooling layer with a stride of 2 and a
window shape of :math:`3\times 3`.

The second significant difference between NiN and both AlexNet and VGG
is that NiN avoids fully connected layers altogether. Instead, NiN uses
a NiN block with a number of output channels equal to the number of
label classes, followed by a *global* average pooling layer, yielding a
vector of logits. This design significantly reduces the number of
required model parameters, albeit at the expense of a potential increase
in training time.
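The NumPy sketch below illustrates this head on a random stand-in for
the output of the last NiN block (a :math:`10 \times 5 \times 5` tensor
for one example, the shape this block produces for a
:math:`224 \times 224` input; the values are random, not real
activations). Global average pooling simply averages each channel into
one logit and has no parameters. The sketch also offers one way to see
why the :math:`1 \times 1` nonlinearities matter: purely linear channel
mixing commutes with the average, whereas a ReLU applied at each
location does not.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, h, w = 10, 5, 5

# Random stand-in for the last NiN block's output for one example
# (channels-first): one 5 x 5 map per class. Not real activations.
feature_map = rng.normal(size=(num_classes, h, w))

# Global average pooling: average each channel over all spatial locations.
# It has no learnable parameters and yields one logit per class.
logits = feature_map.mean(axis=(1, 2))
print(logits.shape)  # (10,)

# A random 1 x 1 "convolution" mixing the 10 channels at every location.
W = rng.normal(size=(num_classes, num_classes))
mixed = np.einsum('oc,chw->ohw', W, feature_map)

# Without a nonlinearity, channel mixing commutes with averaging: the
# result depends only on the per-channel means, so spatial structure
# is lost.
print(np.allclose(mixed.mean(axis=(1, 2)), W @ logits))  # True


def relu(z):
    return np.maximum(z, 0)


# With a ReLU applied at each location before averaging, it no longer does.
print(np.allclose(relu(mixed).mean(axis=(1, 2)), relu(W @ logits)))  # False
```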
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # PyTorch
    class NiN(d2l.Classifier):
        def __init__(self, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            self.net = nn.Sequential(
                nin_block(96, kernel_size=11, strides=4, padding=0),
                nn.MaxPool2d(3, stride=2),
                nin_block(256, kernel_size=5, strides=1, padding=2),
                nn.MaxPool2d(3, stride=2),
                nin_block(384, kernel_size=3, strides=1, padding=1),
                nn.MaxPool2d(3, stride=2),
                nn.Dropout(0.5),
                nin_block(num_classes, kernel_size=3, strides=1, padding=1),
                nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling
                nn.Flatten())
            self.net.apply(d2l.init_cnn)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # MXNet
    class NiN(d2l.Classifier):
        def __init__(self, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            self.net = nn.Sequential()
            self.net.add(
                nin_block(96, kernel_size=11, strides=4, padding=0),
                nn.MaxPool2D(pool_size=3, strides=2),
                nin_block(256, kernel_size=5, strides=1, padding=2),
                nn.MaxPool2D(pool_size=3, strides=2),
                nin_block(384, kernel_size=3, strides=1, padding=1),
                nn.MaxPool2D(pool_size=3, strides=2),
                nn.Dropout(0.5),
                nin_block(num_classes, kernel_size=3, strides=1, padding=1),
                nn.GlobalAvgPool2D(),
                nn.Flatten())
            self.net.initialize(init.Xavier())

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # JAX
    class NiN(d2l.Classifier):
        lr: float = 0.1
        num_classes: int = 10  # annotated so that Flax treats it as a field
        training: bool = True

        def setup(self):
            self.net = nn.Sequential([
                nin_block(96, kernel_size=(11, 11), strides=(4, 4),
                          padding=(0, 0)),
                lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
                nin_block(256, kernel_size=(5, 5), strides=(1, 1),
                          padding=(2, 2)),
                lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
                nin_block(384, kernel_size=(3, 3), strides=(1, 1),
                          padding=(1, 1)),
                lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
                nn.Dropout(0.5, deterministic=not self.training),
                nin_block(self.num_classes, kernel_size=(3, 3),
                          strides=(1, 1), padding=(1, 1)),
                # Global average pooling over height and width (NHWC layout),
                # independent of the input resolution
                lambda x: jnp.mean(x, axis=(1, 2), keepdims=True),
                lambda x: x.reshape((x.shape[0], -1))  # flatten
            ])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    class NiN(d2l.Classifier):
        def __init__(self, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            self.net = tf.keras.models.Sequential([
                nin_block(96, kernel_size=11, strides=4, padding='valid'),
                tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
                nin_block(256, kernel_size=5, strides=1, padding='same'),
                tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
                nin_block(384, kernel_size=3, strides=1, padding='same'),
                tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
                tf.keras.layers.Dropout(0.5),
                nin_block(num_classes, kernel_size=3, strides=1,
                          padding='same'),
                tf.keras.layers.GlobalAvgPool2D(),
                tf.keras.layers.Flatten()])
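We create a data example (a single-channel :math:`224 \times 224`
image) to see the output shape of each block. Note that the JAX and
TensorFlow implementations use the channels-last layout, so their shapes
list the channel dimension last.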
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # PyTorch
    NiN().layer_summary((1, 1, 224, 224))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Sequential output shape:         torch.Size([1, 96, 54, 54])
    MaxPool2d output shape:          torch.Size([1, 96, 26, 26])
    Sequential output shape:         torch.Size([1, 256, 26, 26])
    MaxPool2d output shape:          torch.Size([1, 256, 12, 12])
    Sequential output shape:         torch.Size([1, 384, 12, 12])
    MaxPool2d output shape:          torch.Size([1, 384, 5, 5])
    Dropout output shape:            torch.Size([1, 384, 5, 5])
    Sequential output shape:         torch.Size([1, 10, 5, 5])
    AdaptiveAvgPool2d output shape:  torch.Size([1, 10, 1, 1])
    Flatten output shape:            torch.Size([1, 10])

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # MXNet
    NiN().layer_summary((1, 1, 224, 224))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Sequential output shape:       (1, 96, 54, 54)
    MaxPool2D output shape:        (1, 96, 26, 26)
    Sequential output shape:       (1, 256, 26, 26)
    MaxPool2D output shape:        (1, 256, 12, 12)
    Sequential output shape:       (1, 384, 12, 12)
    MaxPool2D output shape:        (1, 384, 5, 5)
    Dropout output shape:          (1, 384, 5, 5)
    Sequential output shape:       (1, 10, 5, 5)
    GlobalAvgPool2D output shape:  (1, 10, 1, 1)
    Flatten output shape:          (1, 10)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # JAX
    NiN(training=False).layer_summary((1, 224, 224, 1))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Sequential output shape:  (1, 54, 54, 96)
    function output shape:    (1, 26, 26, 96)
    Sequential output shape:  (1, 26, 26, 256)
    function output shape:    (1, 12, 12, 256)
    Sequential output shape:  (1, 12, 12, 384)
    function output shape:    (1, 5, 5, 384)
    Dropout output shape:     (1, 5, 5, 384)
    Sequential output shape:  (1, 5, 5, 10)
    function output shape:    (1, 1, 1, 10)
    function output shape:    (1, 10)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    NiN().layer_summary((1, 224, 224, 1))

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Sequential output shape:              (1, 54, 54, 96)
    MaxPooling2D output shape:            (1, 26, 26, 96)
    Sequential output shape:              (1, 26, 26, 256)
    MaxPooling2D output shape:            (1, 12, 12, 256)
    Sequential output shape:              (1, 12, 12, 384)
    MaxPooling2D output shape:            (1, 5, 5, 384)
    Dropout output shape:                 (1, 5, 5, 384)
    Sequential output shape:              (1, 5, 5, 10)
    GlobalAveragePooling2D output shape:  (1, 10)
    Flatten output shape:                 (1, 10)
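The spatial sizes above follow from the usual output-size formula
:math:`\lfloor (n + 2p - k)/s \rfloor + 1` for a kernel (or pooling
window) of size :math:`k`, stride :math:`s`, and padding :math:`p`. The
short sketch below reproduces them for a :math:`224 \times 224` input;
``conv_out`` is a hypothetical helper written for illustration, not
part of ``d2l``.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output size along one spatial axis of a convolution or pooling layer
    (hypothetical helper; floor rounding, no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) along one spatial axis. The 1x1
# convolutions inside each NiN block keep height and width unchanged,
# so only each block's first convolution is listed.
layers = [
    ("nin_block 1 (11x11, stride 4)", 11, 4, 0),
    ("max-pool (3x3, stride 2)", 3, 2, 0),
    ("nin_block 2 (5x5, pad 2)", 5, 1, 2),
    ("max-pool (3x3, stride 2)", 3, 2, 0),
    ("nin_block 3 (3x3, pad 1)", 3, 1, 1),
    ("max-pool (3x3, stride 2)", 3, 2, 0),
    ("nin_block 4 (3x3, pad 1)", 3, 1, 1),
]

n = 224
sizes = []
for name, k, s, p in layers:
    n = conv_out(n, k, s, p)
    sizes.append(n)
    print(f"{name:32s} -> {n} x {n}")
```

The final :math:`5 \times 5` map is what global average pooling then
collapses to a single value per channel.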
Training
--------

As before, we use Fashion-MNIST (resized to :math:`224 \times 224`
pixels) to train the model, with the same optimizer that we used for
AlexNet and VGG.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # PyTorch
    model = NiN(lr=0.05)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    # Pass one training batch through the network so that the lazy layers
    # infer their shapes, then apply d2l.init_cnn
    model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
    trainer.fit(model, data)

.. figure:: output_nin_8ad4f3_63_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # MXNet
    model = NiN(lr=0.05)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    trainer.fit(model, data)

.. figure:: output_nin_8ad4f3_66_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # JAX
    model = NiN(lr=0.05)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    trainer.fit(model, data)

.. figure:: output_nin_8ad4f3_69_0.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    trainer = d2l.Trainer(max_epochs=10)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    with d2l.try_gpu():
        model = NiN(lr=0.05)
        trainer.fit(model, data)

.. figure:: output_nin_8ad4f3_72_0.svg
Summary
-------

NiN has dramatically fewer parameters than AlexNet and VGG. This stems
primarily from the fact that it needs no giant fully connected layers.
Instead, it uses global average pooling to aggregate across all image
locations after the last stage of the network body. This obviates the
need for expensive (learned) reduction operations and replaces them with
a simple average. What surprised researchers at the time was the fact
that this averaging operation did not harm accuracy. Note that averaging
across a low-resolution representation (with many channels) also adds
to the amount of translation invariance that the network can handle.

Replacing some of the wide-kernel convolutions with :math:`1 \times 1`
convolutions reduces the number of parameters further, while still
adding a significant amount of nonlinearity across channels at each
location. Both :math:`1 \times 1` convolutions and global average
pooling significantly influenced subsequent CNN designs.

Exercises
---------

1. Why are there two :math:`1\times 1` convolutional layers per NiN
   block? Increase their number to three. Reduce their number to one.
   What changes?
2. What changes if you replace the :math:`1 \times 1` convolutions by
   :math:`3 \times 3` convolutions?
3. What happens if you replace the global average pooling by a fully
   connected layer (speed, accuracy, number of parameters)?
4. Calculate the resource usage for NiN.

   1. What is the number of parameters?
   2. What is the amount of computation?
   3. What is the amount of memory needed during training?
   4. What is the amount of memory needed during prediction?

5. What are possible problems with reducing the
   :math:`384 \times 5 \times 5` representation to a
   :math:`10 \times 5 \times 5` representation in one step?
6. Use the structural design decisions in VGG that led to VGG-11,
   VGG-16, and VGG-19 to design a family of NiN-like networks.