Train Network Using Cyclical Learning Rate for Snapshot Ensembling

This example shows how to train a network to classify images of objects using a cyclical learning rate schedule and snapshot ensembling for better test accuracy. In the example, you learn how to use a cosine function for the learning rate schedule, take snapshots of the network during training to create a model ensemble, and add L2-norm regularization (weight decay) to the training loss.

This example trains a residual network [1] on the CIFAR-10 data set [2] with a custom cyclical learning rate. At each iteration, the solver uses the learning rate given by the shifted cosine function [3] $\alpha(t) = \frac{\alpha_0}{2}\left(\cos\left(\frac{\pi\,\operatorname{mod}(t-1,\,T/M)}{T/M}\right)+1\right)$, where $t$ is the iteration number, $T$ is the total number of training iterations, $\alpha_0$ is the initial learning rate, and $M$ is the number of cycles (snapshots). This learning rate schedule effectively splits the training process into M cycles. Each cycle begins with a large learning rate that decays monotonically, forcing the network to explore different local minima. At the end of each training cycle, you take a snapshot of the network (that is, you save the model at this iteration). You later average the predictions of all the snapshot models, a technique known as snapshot ensembling [4], to improve the final test accuracy.
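To get a feel for the schedule, the following minimal sketch evaluates the shifted cosine for a short run. The values of T and M here are illustrative assumptions, not the settings used later in this example.

```matlab
% Illustrative values only (assumed for this sketch).
alpha0 = 0.1;          % initial learning rate
T = 1000;              % total number of training iterations
M = 5;                 % number of cycles (snapshots)
cycleLength = T/M;     % iterations per cycle

% Evaluate the shifted cosine schedule at every iteration.
t = 1:T;
alpha = 0.5*alpha0*(cos(pi*mod(t-1,cycleLength)/cycleLength) + 1);

% The rate restarts at alpha0 on the first iteration of every cycle and
% is close to zero on the last iteration of every cycle.
cycleStarts = alpha(1:cycleLength:T);
cycleEnds = alpha(cycleLength:cycleLength:T);
disp(cycleStarts)
disp(cycleEnds)
```

Because the rate is nearly zero at the end of each cycle, the network has settled close to a minimum when the snapshot is taken, and the restart to alpha0 then pushes it away again.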

Prepare Data

Download the CIFAR-10 data set [2]. The data set contains 60,000 images. Each image is 32-by-32 in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.

datadir = tempdir; 
downloadCIFARData(datadir);

Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images.

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);
classes = categories(YTrain);
numClasses = numel(classes);

You can display a random sample of the training images using the following code.

figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),'ThumbnailSize',[96,96]);
imshow(im)

Create an augmentedImageDatastore object to use for network training. During training, the datastore randomly flips the training images along the vertical axis and randomly translates them up to four pixels horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter);

Define Network Architecture

Create a residual network [1] with six standard convolutional units (two units per stage) and a width of 16. The total network depth is 2*6+2 = 14. In addition, specify the average image using the 'Mean' option in the image input layer.

netWidth = 16;
layers = [
    imageInputLayer(imageSize,'Name','input','Mean', mean(XTrain,4))
    convolution2dLayer(3,netWidth,'Padding','same','Name','convInp')
    batchNormalizationLayer('Name','BNInp')
    reluLayer('Name','reluInp')
    
    convolutionalUnit(netWidth,1,'S1U1')
    additionLayer(2,'Name','add11')
    reluLayer('Name','relu11')
    convolutionalUnit(netWidth,1,'S1U2')
    additionLayer(2,'Name','add12')
    reluLayer('Name','relu12')
    
    convolutionalUnit(2*netWidth,2,'S2U1')
    additionLayer(2,'Name','add21')
    reluLayer('Name','relu21')
    convolutionalUnit(2*netWidth,1,'S2U2')
    additionLayer(2,'Name','add22')
    reluLayer('Name','relu22')
    
    convolutionalUnit(4*netWidth,2,'S3U1')
    additionLayer(2,'Name','add31')
    reluLayer('Name','relu31')
    convolutionalUnit(4*netWidth,1,'S3U2')
    additionLayer(2,'Name','add32')
    reluLayer('Name','relu32')
    
    averagePooling2dLayer(8,'Name','globalPool')
    fullyConnectedLayer(10,'Name','fcFinal')
    ];

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'reluInp','add11/in2');
lgraph = connectLayers(lgraph,'relu11','add12/in2');
skip1 = [
    convolution2dLayer(1,2*netWidth,'Stride',2,'Name','skipConv1')
    batchNormalizationLayer('Name','skipBN1')];
lgraph = addLayers(lgraph,skip1);
lgraph = connectLayers(lgraph,'relu12','skipConv1');
lgraph = connectLayers(lgraph,'skipBN1','add21/in2');
lgraph = connectLayers(lgraph,'relu21','add22/in2');
skip2 = [
    convolution2dLayer(1,4*netWidth,'Stride',2,'Name','skipConv2')
    batchNormalizationLayer('Name','skipBN2')];
lgraph = addLayers(lgraph,skip2);
lgraph = connectLayers(lgraph,'relu22','skipConv2');
lgraph = connectLayers(lgraph,'skipBN2','add31/in2');
lgraph = connectLayers(lgraph,'relu31','add32/in2');

Plot the ResNet architecture.

figure;
plot(lgraph)
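As a quick sanity check on the architecture, the spatial size of the feature maps and the depth count can be worked out by hand. The sketch below assumes that 'same' padding with a stride of 2 halves the spatial size, and it reads the depth formula as two convolutions per unit plus the input convolution and the final fully connected layer. The variable names are introduced here for illustration only.

```matlab
inputSize = 32;              % CIFAR-10 images are 32-by-32
stageStrides = [1 2 2];      % stride of the first unit in stages 1, 2, and 3

% Spatial size of the feature maps at the output of each stage.
featureMapSize = inputSize./cumprod(stageStrides);

% Two convolutions per unit, plus convInp and fcFinal.
numUnits = 6;
depth = 2*numUnits + 2;

disp(featureMapSize)
disp(depth)
```

The final stage produces 8-by-8 feature maps, so the 8-by-8 average pooling layer named globalPool acts as global pooling over the whole map.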

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph);

Define Model Gradients Function

Create the helper function modelGradients, listed at the end of the example. The function takes in a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet. This function also returns the loss and the state of the nonlearnable parameters of the network at a given iteration.

Specify Training Options

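Specify the training options. Train for 200 epochs with a mini-batch size of 64. Use stochastic gradient descent with momentum (SGDM) with a momentum of 0.9 and a weight decay factor of 1e-4, and initialize the solver velocity as an empty array.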

numEpochs = 200;
miniBatchSize = 64;

numObservations = numel(YTrain);

velocity = [];
momentum = 0.9;
weightDecay = 1e-4;

Specify the training options specific to the cyclical learning rate. alpha0 is the initial learning rate and numSnapshots is the number of cycles, which equals the number of snapshots taken during training.

alpha0 = 0.1;
numSnapshots = 5;
epochsPerSnapshot = numEpochs./numSnapshots; 
iterationsPerSnapshot = ceil(numObservations./miniBatchSize)*numEpochs./numSnapshots;
modelPrefix = "SnapshotEpoch";
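The cycle length in iterations follows directly from these settings. The sketch below repeats the arithmetic with the values from this example, assuming the final partial mini-batch of each epoch is kept (which is why the division is rounded up). The names iterationsPerEpoch and totalIterations are introduced here for illustration.

```matlab
numObservations = 50000;     % CIFAR-10 training images
miniBatchSize = 64;
numEpochs = 200;
numSnapshots = 5;

% Number of mini-batches per epoch, counting the last partial one.
iterationsPerEpoch = ceil(numObservations/miniBatchSize);

% Length of one learning rate cycle and of the whole run, in iterations.
iterationsPerSnapshot = iterationsPerEpoch*numEpochs/numSnapshots;
totalIterations = iterationsPerEpoch*numEpochs;

disp([iterationsPerEpoch iterationsPerSnapshot totalIterations])
```

With these values, each epoch has 782 iterations, each cycle spans 40 epochs (31,280 iterations), and the full run has 156,400 iterations.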

Visualize the training progress in a plot.

plots = "training-progress";

Initialize the training figure.

if plots == "training-progress"
    [lossLine,learnRateLine] = plotLossAndLearnRate();
end

Train Model

Use minibatchqueue to process and manage mini-batches of images during training. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to one-hot encode the class labels.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher.

augimdsTrain.MiniBatchSize = miniBatchSize;

mbqTrain = minibatchqueue(augimdsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Train the model using a custom training loop. For each epoch, shuffle the datastore, loop over mini-batches of data, and save the model (snapshot) if the current epoch is a multiple of epochsPerSnapshot. At the end of each epoch, display the training progress. For each mini-batch:

  • Evaluate the model gradients and loss using dlfeval and the modelGradients function.

  • Update the state of the nonlearnable parameters of the network.

  • Determine the learning rate for the cyclical learning rate schedule.

  • Update the network parameters using the sgdmupdate function.

  • Plot the loss and learning rate at each iteration.

For this example, the training took approximately 14 hours on an NVIDIA™ TITAN RTX.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
  
    % Shuffle data.
    shuffle(mbqTrain);
    
    % Loop over mini-batches.
    while hasdata(mbqTrain)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX,dlY] = next(mbqTrain);
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state] = dlfeval(@modelGradients,dlnet,dlX,dlY,weightDecay);
        
        % Update the state of nonlearnable parameters.
        dlnet.State = state;
        
        % Determine learning rate for cyclical learning rate schedule.
        learnRate = 0.5*alpha0*(cos((pi*mod(iteration-1,iterationsPerSnapshot)./iterationsPerSnapshot))+1);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lossLine,iteration,double(gather(extractdata(loss))))
            addpoints(learnRateLine, iteration, learnRate);
            sgtitle("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
        
    end
    
    % Save snapshot model once the last epoch of the cycle has finished,
    % when the learning rate has decayed to its minimum.
    if ~mod(epoch,epochsPerSnapshot)
        save(modelPrefix + epoch + ".mat",'dlnet');
    end
    
end
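To illustrate how the cyclical learning rate interacts with the momentum update, the following sketch applies the standard SGDM rule to a toy one-dimensional problem. It is a simplified stand-in written for this illustration, not the internal implementation of sgdmupdate, and the problem, cycle count, and variable names w, v, and grad are assumptions.

```matlab
% Toy problem (assumed): minimize f(w) = 0.5*w^2, whose gradient is w.
w = 5;                 % parameter
v = 0;                 % velocity
alpha0 = 0.1;          % initial learning rate
momentum = 0.9;
numIterations = 200;
numCycles = 2;
cycleLength = numIterations/numCycles;

for iteration = 1:numIterations
    % Cyclical learning rate, as in the training loop.
    learnRate = 0.5*alpha0*(cos(pi*mod(iteration-1,cycleLength)/cycleLength) + 1);

    % Standard SGDM step: accumulate velocity, then move the parameter.
    grad = w;
    v = momentum*v - learnRate*grad;
    w = w + v;
end

disp(w)
```

On this convex toy problem the restarts have nothing new to discover, so the parameter simply approaches zero. In a deep network, each restart can move the parameters to a different region of the loss surface.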

Create Snapshot Ensemble and Test Model

Combine the M snapshots of the network taken during training to form a final ensemble and test the classification accuracy of the model. The ensemble predictions correspond to the average of the output of the fully connected layer from all M individual models.

Test the model on the test data provided with the CIFAR-10 data set. Manage the test data set using a minibatchqueue object with the same settings as for the training data.

augimdsTest = augmentedImageDatastore(imageSize,XTest,YTest);
augimdsTest.MiniBatchSize = miniBatchSize;

mbqTest = minibatchqueue(augimdsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Evaluate the accuracy of each snapshot network. Use the modelPredictions function defined at the end of this example to iterate over all the data in the test data set. The function returns the output of the fully connected layer from the model, the predicted classes, and the comparison with the true class.

modelName = cell(numSnapshots+1,1);
fcOutput = zeros(numClasses,numel(YTest),numSnapshots+1);
classPredictions = cell(1,numSnapshots+1);
modelAccuracy = zeros(numSnapshots+1,1);

for m = 1:numSnapshots
    modelName{m} = modelPrefix + m*epochsPerSnapshot;
    load(modelName{m} + ".mat");
    
    reset(mbqTest);
    [fcOutputTest,classPredTest,classCorrTest] = modelPredictions(dlnet,mbqTest,classes);
    
    fcOutput(:,:,m) = fcOutputTest;    
    classPredictions{m} = classPredTest;
    modelAccuracy(m) = 100*mean(classCorrTest);
    
    disp(modelName{m} + " accuracy: " + modelAccuracy(m) + "%")
end
SnapshotEpoch40 accuracy: 88.35%
SnapshotEpoch80 accuracy: 89.93%
SnapshotEpoch120 accuracy: 90.51%
SnapshotEpoch160 accuracy: 90.33%
SnapshotEpoch200 accuracy: 90.63%

To determine the output of the ensemble, compute the average of the fully connected layer outputs of all snapshot networks. Find the predicted classes of the ensemble using the onehotdecode function. Compare them with the true classes to evaluate the accuracy of the ensemble.

fcOutput(:,:,end) = mean(fcOutput(:,:,1:end-1),3);
classPredictions{end} = onehotdecode(softmax(fcOutput(:,:,end)),classes,1,'categorical');

classCorrEnsemble = classPredictions{end} == YTest';
modelAccuracy(end) = 100*mean(classCorrEnsemble);

modelName{end} = "Ensemble model";
disp("Ensemble accuracy: " + modelAccuracy(end) + "%")
Ensemble accuracy: 91.59%
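The averaging step can be followed on a small hand-made case. The scores below are hypothetical values for three classes, two observations, and three snapshots. They are not outputs of the trained networks.

```matlab
% Hypothetical scores: classes-by-observations-by-snapshots.
scores = cat(3, ...
    [2 0; 1 1; 0 3], ...
    [0 1; 3 0; 1 2], ...
    [1 0; 2 2; 0 1]);

% Average over the snapshot dimension, then pick the highest-scoring class.
ensembleScores = mean(scores,3);
[~,ensembleClass] = max(ensembleScores,[],1);

% Class picked by each snapshot on its own, for comparison.
[~,snapshotClass] = max(scores,[],1);
snapshotClass = squeeze(snapshotClass)';   % snapshots-by-observations

disp(ensembleClass)
disp(snapshotClass)
```

In this toy case the snapshots disagree on both observations, and the ensemble picks class 2 for the first observation and class 3 for the second, which matches the majority of the individual snapshots.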

Plot Accuracy

Plot the accuracy on the test data set for all snapshot models and the ensemble model.

figure;
bar(modelAccuracy);
ylabel('Accuracy (%)');
xticklabels(modelName)
xtickangle(45)
title('Model accuracy')

Helper Functions

Model Gradients Function

The modelGradients function takes in a dlnetwork object dlnet, a mini-batch of input data dlX, the labels Y, and the parameter for weight decay. The function returns the gradients, the loss, and the state of the nonlearnable parameters. To compute the gradients automatically, use the dlgradient function.

function [gradients,loss,state] = modelGradients(dlnet,dlX,Y,weightDecay)
    
    [dlYPred,state] = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    loss = crossentropy(dlYPred, Y);
    
    % L2-regularization (weight decay)
    allParams = dlnet.Learnables(dlnet.Learnables.Parameter == "Weights" | dlnet.Learnables.Parameter == "Scale",:).Value;
    l2Norm = cellfun(@(x) sum(x.^2,'All'),allParams,'UniformOutput',false);
    l2Norm = sum(cat(1,l2Norm{:}));
    loss = loss + weightDecay*0.5*l2Norm;
    
    gradients = dlgradient(loss,dlnet.Learnables);
end
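The effect of the weight decay term on the gradients can be checked in isolation. The sketch below uses a toy parameter vector with assumed values and differentiates only the penalty term, which should yield weightDecay times the parameter itself. The anonymous function penaltyGradient is introduced here for illustration.

```matlab
% Toy parameter vector (assumed values), wrapped in a dlarray so that
% dlfeval can trace it for automatic differentiation.
w = dlarray([1 -2 3]);
weightDecay = 1e-4;

% Gradient of the penalty 0.5*weightDecay*sum(w.^2) with respect to w.
penaltyGradient = @(w) dlgradient(0.5*weightDecay*sum(w.^2,'all'),w);
grad = dlfeval(penaltyGradient,w);

% The two displayed vectors should agree.
disp(extractdata(grad))
disp(weightDecay*[1 -2 3])
```

Adding this term to the loss therefore pulls every regularized weight toward zero by an amount proportional to its current value at each update.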

Model Predictions Function

The modelPredictions function takes as input a dlnetwork object dlnet, a minibatchqueue object mbq of input data, and the class names classes. It computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score and then compares the prediction with the true class. The function returns the network output, the class predictions, and a vector of ones and zeros that represents correct and incorrect predictions.

function [rawPredictions,classPredictions,classCorr] = modelPredictions(dlnet,mbq,classes)
    rawPredictions = [];
    classPredictions = [];
    classCorr = [];
    
    while hasdata(mbq)
        [dlX,dlY] = next(mbq);
        
        % Make predictions
        dlYPred = predict(dlnet,dlX);
        rawPredictions = [rawPredictions extractdata(gather(dlYPred))];
        
        % Convert network output to probabilities and determine predicted
        % classes
        dlYPred = softmax(dlYPred);
        YPredBatch = onehotdecode(dlYPred,classes,1);
        classPredictions = [classPredictions YPredBatch];
        
        % Compare predicted and true classes
        Y = onehotdecode(dlY,classes,1);
        classCorr = [classCorr YPredBatch == Y];
    end
    
end

Plot Loss and Learning Rate Function

The plotLossAndLearnRate function initializes the plots for displaying the loss and learning rate at each iteration during training.

function [lossLine, learnRateLine] = plotLossAndLearnRate()
    figure
    
    subplot(2,1,1); 
    lossLine = animatedline('Color',[0.85 0.325 0.098]);
    title('Loss');
    xlabel('Iteration')
    ylabel('Loss')
    grid on
    
    subplot(2,1,2); 
    learnRateLine = animatedline('Color',[0 0.447 0.741]);
    title('Learning rate');
    xlabel('Iteration')
    ylabel('Learning rate')
    grid on
end

Convolutional Unit Function

The convolutionalUnit(numF,stride,tag) function creates an array of layers with two convolutional layers and corresponding batch normalization and ReLU layers. numF is the number of convolutional filters, stride is the stride of the first convolutional layer, and tag is a character vector that is prepended to all layer names.

function layers = convolutionalUnit(numF,stride,tag)
    layers = [
        convolution2dLayer(3,numF,'Padding','same','Stride',stride,'Name',[tag,'conv1'])
        batchNormalizationLayer('Name',[tag,'BN1'])
        reluLayer('Name',[tag,'relu1'])
        convolution2dLayer(3,numF,'Padding','same','Name',[tag,'conv2'])
        batchNormalizationLayer('Name',[tag,'BN2'])];
end

Data Preprocessing Function

The preprocessMiniBatch function preprocesses the data using the following steps:

  1. Extract the image data from the incoming cell array and concatenate into a numeric array. Concatenating the 32-by-32-by-3 images along the fourth dimension produces a 4-D array whose fourth dimension is the batch dimension.

  2. Extract the label data from the incoming cell arrays and concatenate into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function [X,Y] = preprocessMiniBatch(XCell,YCell)
    % Extract image data from cell and concatenate
    X = cat(4,XCell{:});
    % Extract label data from cell and concatenate
    Y = cat(2,YCell{:});
    
    % One-hot encode labels
    Y = onehotencode(Y,1);
end
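The label handling in steps 2 and 3 can be tried on a few hand-written labels. The class names and labels below are hypothetical and stand in for the CIFAR-10 categories.

```matlab
% Hypothetical class names and a 1-by-4 row of labels, as produced by
% concatenating the label cells along the second dimension.
classNames = ["cat" "dog" "ship"];
Y = categorical(["dog" "ship" "dog" "cat"],classNames);

% One-hot encode along the first dimension: classes-by-observations.
YEncoded = onehotencode(Y,1);

disp(size(YEncoded))
disp(YEncoded)
```

Each column of YEncoded holds a single 1 in the row of the corresponding class, so the array has the same classes-by-observations layout as the output of the fully connected layer for a mini-batch.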

References

[1] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning for image recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.

[2] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

[3] Loshchilov, Ilya, and Frank Hutter. "SGDR: Stochastic gradient descent with warm restarts." (2016). arXiv preprint arXiv:1608.03983.

[4] Huang, Gao, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger. "Snapshot ensembles: Train 1, get m for free." (2017). arXiv preprint arXiv:1704.00109.
