
Update Batch Normalization Statistics in Custom Training Loop

This example shows how to update the network state in a custom training loop.

A batch normalization layer normalizes each input channel across a mini-batch. To speed up training of convolutional neural networks and reduce the sensitivity to network initialization, use batch normalization layers between convolutional layers and nonlinearities, such as ReLU layers.

During training, batch normalization layers first normalize the activations of each channel by subtracting the mini-batch mean and dividing by the mini-batch standard deviation. Then, the layer scales the normalized activations by a learnable scale factor γ and shifts them by a learnable offset β.
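For a single channel, the two steps above can be written as follows, where μ_B and σ_B² are the mean and variance of that channel over the mini-batch. The small constant ε is not named in the text above. It is included here on the assumption that the layer adds one (its Epsilon property) to guard against division by zero.

```latex
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
```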

When network training finishes, batch normalization layers calculate the mean and variance over the full training set and store the values in the TrainedMean and TrainedVariance properties. When you use a trained network to make predictions on new images, the batch normalization layers use the trained mean and variance instead of the mini-batch mean and variance to normalize the activations.
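The following minimal sketch contrasts the two modes using plain MATLAB arithmetic, without any Deep Learning Toolbox calls. It normalizes a toy 4-D array once with mini-batch statistics and once with stored statistics. The scale of 1, offset of 0, and ε = 1e-5 are illustrative assumptions, and trainedMeanToy and trainedVarianceToy are hypothetical stand-ins for the TrainedMean and TrainedVariance values. They are not read from a real network.

```matlab
% Toy activations in "SSCB" layout: 2x2 spatial, 3 channels, 4 observations.
rng(0);
XToy = randn(2,2,3,4);
epsilonToy = 1e-5;              % assumed stability constant
scaleToy = ones(1,1,3);         % stands in for the learnable scale (gamma)
offsetToy = zeros(1,1,3);       % stands in for the learnable offset (beta)

% Training mode: per-channel statistics over the spatial and batch dimensions.
muBatch = mean(XToy,[1 2 4]);           % 1x1x3
varBatch = var(XToy,1,[1 2 4]);         % 1x1x3, normalized by N (weight 1)
YTrainMode = scaleToy.*(XToy - muBatch)./sqrt(varBatch + epsilonToy) + offsetToy;

% Prediction mode: stored statistics replace the mini-batch statistics.
trainedMeanToy = zeros(1,1,3);          % hypothetical stored mean
trainedVarianceToy = ones(1,1,3);       % hypothetical stored variance
YPredictMode = scaleToy.*(XToy - trainedMeanToy)./sqrt(trainedVarianceToy + epsilonToy) + offsetToy;
```

In training mode, each channel of YTrainMode has a mean of about 0 and a variance of about 1, whatever the input. In prediction mode, the output depends only on the stored statistics, so the result for one image does not depend on the other images in the batch.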

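To compute the data set statistics, batch normalization layers keep track of the mini-batch statistics by using a continually updating state. The forward function returns this updated state but does not modify the network itself. If you are implementing a custom training loop, then you must update the network state between mini-batches by assigning the returned state to the State property of the network.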
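One way to picture a continually updating state is as an exponential moving average. The sketch below tracks a running mean for a single channel. It assumes a decay factor of 0.1 and an initial estimate of 0. Both numbers and the batch means are illustrative assumptions, so check the batchNormalizationLayer documentation for the exact update rule the layer uses.

```matlab
batchMeansToy = [2.0 4.0 3.0];   % hypothetical mini-batch means for one channel
decayToy = 0.1;                  % assumed decay factor
runningMean = 0;                 % assumed initial estimate

for k = 1:numel(batchMeansToy)
    % Blend the new mini-batch mean into the running estimate.
    runningMean = decayToy*batchMeansToy(k) + (1 - decayToy)*runningMean;
end

runningMean   % 0.822 after three updates
```

If the updated value is never written back, the estimate stays at its initial value of 0. A custom training loop that skips the state assignment has the same problem.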

Load Training Data

The digitTrain4DArrayData function loads images of handwritten digits and their digit labels. Create an arrayDatastore object for the images and another for the labels, and then use the combine function to make a single datastore that contains all of the training data. Extract the class names.

[XTrain,TTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsTTrain = arrayDatastore(TTrain);

dsTrain = combine(dsXTrain,dsTTrain);

classNames = categories(TTrain);
numClasses = numel(classNames);

Define Network

Define the network and specify the average image using the Mean option in the image input layer.

layers = [
    imageInputLayer([28 28 1],Mean=mean(XTrain,4))
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
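Specifying the Mean option makes the image input layer subtract the average image from every input. The minimal sketch below shows that zero-centering step with plain arrays. XToy is a hypothetical stack of four 3-by-3 images, not the digit data.

```matlab
rng(1);
XToy = rand(3,3,1,4);            % hypothetical images, H x W x C x N
avgImage = mean(XToy,4);         % 3x3 average image, like the value passed to Mean=
XCentered = XToy - avgImage;     % implicit expansion subtracts it from every image
```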

Create a dlnetwork object from the layer array.

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

View the network state. Each batch normalization layer has a TrainedMean parameter and a TrainedVariance parameter, which hold the estimates of the data set mean and variance, respectively.

net.State
ans=6×3 table
        Layer            Parameter             Value      
    _____________    _________________    ________________

    "batchnorm_1"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_1"    "TrainedVariance"    {1×1×20 dlarray}
    "batchnorm_2"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_2"    "TrainedVariance"    {1×1×20 dlarray}
    "batchnorm_3"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_3"    "TrainedVariance"    {1×1×20 dlarray}

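To read a particular state value, you can filter the rows of the State table by parameter name instead of relying on row positions. The sketch below requires Deep Learning Toolbox. It builds a small hypothetical network (layersToy, netToy) with a single batch normalization layer so that it runs on its own.

```matlab
% Small hypothetical network with one batch normalization layer.
layersToy = [
    imageInputLayer([8 8 1],Normalization="none")
    convolution2dLayer(3,4)
    batchNormalizationLayer
    reluLayer];
netToy = dlnetwork(layersToy);

% Select state rows by parameter name rather than by row position.
isMean = netToy.State.Parameter == "TrainedMean";
isVariance = netToy.State.Parameter == "TrainedVariance";
muToy = netToy.State.Value{isMean};          % 1x1x4 dlarray, one value per channel
sigma2Toy = netToy.State.Value{isVariance};  % 1x1x4 dlarray
```

With several batch normalization layers, as in net above, also compare the Layer column (for example, against "batchnorm_2") so that the selection matches exactly one row.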
Define Model Loss Function

Create the function modelLoss, listed at the end of the example. The function takes as input a dlnetwork object and a mini-batch of input data with corresponding labels. It returns the loss, the gradients of the loss with respect to the learnable parameters, and the updated network state.
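The loss that modelLoss computes with crossentropy compares the softmax scores with the one-hot encoded labels. The following minimal sketch computes that value by hand in plain MATLAB. It assumes crossentropy's default normalization divides the summed loss by the number of observations. YToy and TToy are hypothetical 3-class scores and targets for two observations.

```matlab
YToy = [0.7 0.2; 0.2 0.5; 0.1 0.3];   % hypothetical softmax scores, classes x observations
TToy = [1 0; 0 1; 0 0];               % one-hot targets with the same layout
numObsToy = size(YToy,2);

% Sum -T.*log(Y) over classes and observations, then average over observations.
lossToy = -sum(TToy.*log(YToy),'all')/numObsToy   % about 0.5249
```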

Specify Training Options

Train for five epochs using a mini-batch size of 128. For the SGDM optimization, specify a learning rate of 0.01 and a momentum of 0.9.

numEpochs = 5;
miniBatchSize = 128;

learnRate = 0.01;
momentum = 0.9;

Train Model

Use minibatchqueue to process and manage the mini-batches of images. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to one-hot encode the class labels.

  • Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
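The one-hot encoding in the first bullet turns each categorical label into a column with a single 1 in the row of its class. The minimal sketch below gets the same classes-by-observations layout in plain MATLAB using category indices. It approximates what onehotencode(T,1) produces, and the output data type of onehotencode may differ from the logical array here. TToy is a hypothetical label vector.

```matlab
TToy = categorical(["b";"a";"c";"a"]);        % hypothetical labels
classesToy = categories(TToy);                % {'a';'b';'c'}
codes = double(TToy);                         % category indices: [2;1;3;1]

% Compare each class index (column) with each label index (row).
TEncoded = (1:numel(classesToy))' == codes';  % numClasses-by-numObservations logical
```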

mbq = minibatchqueue(dsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""]);

Initialize the velocity parameter for the SGDM solver.

velocity = [];

To update the progress bar of the training progress monitor, calculate the total number of training iterations.

numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = ceil(numObservationsTrain/miniBatchSize);
numIterations = numIterationsPerEpoch * numEpochs;
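As a worked check of this arithmetic, the sketch below assumes a training set of 5000 images. By default, minibatchqueue is assumed to return the final, smaller mini-batch, so the count rounds up with ceil. The Toy variable names avoid overwriting the workspace variables above.

```matlab
numObservationsToy = 5000;   % assumed size of the training set
miniBatchSizeToy = 128;
numEpochsToy = 5;

itersPerEpochToy = ceil(numObservationsToy/miniBatchSizeToy)   % 40 (39 full batches + 1 partial)
totalItersToy = itersPerEpochToy*numEpochsToy                  % 200
```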

Initialize the TrainingProgressMonitor object.

monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");

Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. At the end of each iteration, display the training progress. For each mini-batch:

  • Evaluate the model loss, gradients, and state using dlfeval and the modelLoss function and update the network state.

  • Update the network parameters using the sgdmupdate function.
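The second step applies stochastic gradient descent with momentum. The following minimal sketch shows the standard velocity form of that update on a plain numeric vector. It assumes this form matches the rule sgdmupdate documents. wToy and gToy are hypothetical parameters and a fixed hypothetical gradient, and the learning rate and momentum match the values specified earlier.

```matlab
wToy = [1; 2];                 % hypothetical parameters
gToy = [0.5; -1];              % hypothetical gradient, held fixed for illustration
vToy = zeros(size(wToy));      % velocity starts at zero, like velocity = []
lr = 0.01;
mom = 0.9;

for step = 1:2
    % Decay the previous velocity and add the new gradient step.
    vToy = mom*vToy - lr*gToy;
    wToy = wToy + vToy;
end

wToy   % [0.9855; 2.029]
```

Because the velocity carries over between iterations, the second step moves farther than the first. This is why the training loop passes velocity back into sgdmupdate on every iteration.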

iteration = 0;
epoch = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop

        iteration = iteration + 1;

        % Read a mini-batch of data. The preprocessMiniBatch function has
        % already one-hot encoded the labels.
        [X,T] = next(mbq);

        % Evaluate the model loss, gradients, and state using dlfeval and the
        % modelLoss function and update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;

        % Update the network parameters using the SGDM optimizer.
        [net, velocity] = sgdmupdate(net, gradients, velocity, learnRate, momentum);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));
        monitor.Progress = 100*(iteration/numIterations);
    end
end
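The line net.State = state is what keeps the batch normalization statistics current. The minimal sketch below, which requires Deep Learning Toolbox, shows that forward returns a new state table and leaves the network unchanged until you assign that table. layersToy, netToy, and XToy are hypothetical stand-ins for the network and mini-batch in the loop above.

```matlab
layersToy = [
    imageInputLayer([8 8 1],Normalization="none")
    convolution2dLayer(3,4)
    batchNormalizationLayer
    reluLayer];
netToy = dlnetwork(layersToy);

% Hypothetical mini-batch of 16 random images in "SSCB" format.
rng(0);
XToy = dlarray(rand(8,8,1,16,"single"),"SSCB");

% forward returns the updated state but does not modify netToy itself.
[~,stateToy] = forward(netToy,XToy);
isMean = stateToy.Parameter == "TrainedMean";
meanBefore = extractdata(netToy.State.Value{isMean});
meanAfter = extractdata(stateToy.Value{isMean});

% Only after this assignment does the network carry the updated statistics.
netToy.State = stateToy;
```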

Test Model

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

Load the test data and create a datastore containing the test images.

[XTest,TTest] = digitTest4DArrayData;
dsTest = arrayDatastore(XTest,IterationDimension=4);

Create a minibatchqueue object that processes and manages mini-batches of images during testing. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatchPredictors, defined at the end of this example.

  • By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Format the images with the dimension labels "SSCB" (spatial, spatial, channel, batch).

mbqTest = minibatchqueue(dsTest,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatchPredictors,...
    MiniBatchFormat="SSCB");

Classify the images using the modelPredictions function, listed at the end of the example.

predictions = modelPredictions(net,mbqTest,classNames);

Evaluate the classification accuracy.

accuracy = mean(predictions == TTest)
accuracy = 0.9958
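The accuracy is the fraction of elementwise matches between two categorical arrays. The following small sketch uses hypothetical predictions and labels.

```matlab
predToy = categorical(["3";"7";"7";"1"]);   % hypothetical predictions
trueToy = categorical(["3";"7";"1";"1"]);   % hypothetical true labels
accuracyToy = mean(predToy == trueToy)       % 3 of 4 match: 0.75
```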

Model Loss Function

The modelLoss function takes as input a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss, the gradients of the loss with respect to the learnable parameters in net, and the network state. To compute the gradients automatically, use the dlgradient function.

function [loss,gradients,state] = modelLoss(net,X,T)

[Y,state] = forward(net,X);

loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

Model Predictions Function

The modelPredictions function takes as input a dlnetwork object net, a minibatchqueue object of input data mbq, and the class names classes. It computes the model predictions by iterating over all the data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.

function predictions = modelPredictions(net,mbq,classes)

predictions = [];

while hasdata(mbq)
    X = next(mbq);

    % Make predictions using the network. In prediction mode, the batch
    % normalization layers use the trained mean and variance.
    Y = predict(net,X);

    % Determine predicted classes.
    YPredBatch = onehotdecode(Y,classes,1);
    predictions = [predictions; YPredBatch'];
end

end
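The decoding step picks, for each observation, the class whose score is largest along the first dimension. The minimal sketch below does this in plain MATLAB with max. It approximates onehotdecode(Y,classes,1), which may handle ties or data types differently. classesToy and scoresToy are hypothetical.

```matlab
classesToy = ["cat";"dog";"bird"];            % hypothetical class names
scoresToy = [0.7 0.1; 0.2 0.3; 0.1 0.6];      % classes x observations

% Index of the highest score in each column.
[~,idxToy] = max(scoresToy,[],1);             % [1 3]
predToy = categorical(classesToy(idxToy),classesToy)
```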

Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses the data using the following steps:

  1. Extract the image data from the incoming cell array and concatenate it along the fourth dimension.

  2. Extract the label data from the incoming cell array and concatenate along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function [X,T] = preprocessMiniBatch(dataX,dataY)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

% Extract label data from cell and concatenate
T = cat(2,dataY{:});

% One-hot encode labels
T = onehotencode(T,1);

end

Mini-Batch Predictors Preprocessing Function

The preprocessMiniBatchPredictors function preprocesses the predictors by extracting the image data from the incoming cell array and concatenating it into a numeric array. Concatenating the images over the fourth dimension produces an array whose third dimension is a singleton, which the network uses as the channel dimension.

function X = preprocessMiniBatchPredictors(dataX)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

end
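The effect of concatenating along the fourth dimension can be checked on a few hypothetical 28-by-28 matrices. The resulting array gains a singleton third dimension.

```matlab
dataToy = {ones(28,28), zeros(28,28), 2*ones(28,28)};   % hypothetical cell of 2-D images
XToy = cat(4,dataToy{:});
size(XToy)      % 28 28 1 3: the singleton third dimension serves as the channel dimension
```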
