Define Custom Learning Rate Schedule
When you train a neural network using the trainnet function, the LearnRateSchedule argument of the trainingOptions function provides several options for customizing the learning rate schedule. It provides built-in schedules such as "piecewise" and "warmup". You can also use learning rate schedule objects that you can customize further, such as piecewiseLearnRate and warmupLearnRate objects. If these built-in options do not provide the functionality that you need, you can specify the learning rate schedule as a function of the epoch number using a function handle.
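As a rough illustration of the function handle option, the following sketch defines an epoch-based decay rule and evaluates it for a few epochs. It assumes that the function handle takes the initial learning rate and the epoch number and returns the learning rate; check the trainingOptions documentation for the exact syntax in your release. The name epochDecay and the decay value of 0.1 are hypothetical and chosen only for illustration.

```matlab
% Hypothetical epoch-based decay rule (name and decay value are illustrative).
% Assumed syntax for the handle: learnRate = f(initialLearnRate,epoch).
decay = 0.1;
epochDecay = @(initialLearnRate,epoch) initialLearnRate ./ (1 + decay*(epoch-1));

% Evaluate the rule for the first five epochs with an initial learning rate of 0.01.
epochs = 1:5;
learnRates = epochDecay(0.01,epochs);
disp([epochs; learnRates]')

% To train with this rule, pass the handle to trainingOptions, for example:
% options = trainingOptions("sgdm",LearnRateSchedule=epochDecay);
```

Because a rule of this kind depends only on the epoch number, it cannot change the learning rate between iterations, which is the limitation that a custom schedule object removes.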
If you need additional flexibility, for example, you want to use a learning rate schedule that changes the learning rate between iterations, or you want to use a learning rate schedule that requires updating and maintaining a state, then you can define your own custom learning rate schedule object using this example as a guide.
To define a custom learning rate schedule, you can use the template provided in this example, which takes you through these steps:
Name the schedule — Give the schedule a name so that you can use it in MATLAB®.
Declare the schedule properties (optional) — Specify the properties of the schedule.
Create the constructor function (optional) — Specify how to construct the schedule and initialize its properties.
Create the update function — Specify how the schedule calculates the learning rate.
This example shows how to define a time-based decay learning rate schedule and use it to train a neural network. A time-based decay learning rate schedule object updates the learning rate every iteration using a decay rule.
The time-based decay learning rate schedule uses this formula to calculate the learning rate:
$$\alpha_k = \frac{\alpha_0}{1 + \lambda (k - 1)}$$
where:
k is the iteration number.
α0 is the base learning rate, specified by the InitialLearnRate option of the trainingOptions function.
λ is the decay value of the schedule.
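As a quick numeric check of the decay rule, this sketch evaluates the same expression that the update function in this example uses, learnRate = initialLearnRate/(1 + decay*(iteration-1)), at a few iterations. The initial learning rate of 0.01 and the decay of 0.01 are the values that appear in the training run later in this example; the choice of iterations is only illustrative.

```matlab
% Values used in this example: the training output reports a learning rate
% of 0.01 at iteration 1, and the schedule is created with a decay of 0.01.
initialLearnRate = 0.01;
decay = 0.01;

% Iterations to inspect (illustrative choice).
iteration = [1 50 100 390];

% Time-based decay rule, as implemented in the update function.
learnRate = initialLearnRate ./ (1 + decay*(iteration-1));

% Print one line per iteration.
fprintf("Iteration %3d: learning rate %.7f\n",[iteration; learnRate]);
```

The printed values (0.0100000, 0.0067114, 0.0050251, and 0.0020450) agree with the LearnRate column of the training output shown later in this example.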
Custom Learning Rate Schedule Template
Copy the custom learning rate schedule template into a new file in MATLAB. This template gives the structure of a schedule class definition. It outlines:
The optional properties block for the schedule properties.
The optional schedule constructor function.
The update function.
classdef myLearnRateSchedule < deep.LearnRateSchedule

    properties
        % (Optional) Schedule properties.

        % Declare schedule properties here.
    end

    methods
        function schedule = myLearnRateSchedule()
            % (Optional) Create a myLearnRateSchedule.
            % This function must have the same name as the class.

            % Define schedule constructor function here.
        end

        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch)
            %UPDATE Update learning rate schedule

            % Define schedule update function here.
        end
    end
end
Name Schedule and Specify Superclass
First, give the schedule a name. In the first line of the class file, replace the existing name myLearnRateSchedule with timeBasedDecayLearnRate.
classdef timeBasedDecayLearnRate < deep.LearnRateSchedule
    ...
end
Next, rename the myLearnRateSchedule constructor function (the first function in the methods section) so that it has the same name as the schedule.
methods
    function schedule = timeBasedDecayLearnRate()
        ...
    end

    ...
end
Save the Schedule
Save the schedule class file in a new file named timeBasedDecayLearnRate.m. The file name must match the schedule name. To use the schedule, you must save the file in the current folder or in a folder on the MATLAB path.
Declare Properties
Declare the schedule properties in the properties section.
By default, custom learning rate schedules have these properties. Do not declare these properties in the properties section.
Property | Description |
---|---|
FrequencyUnit | How often the schedule updates the learning rate. For example, a value of "iteration" (used in this example) updates the learning rate every iteration. |
NumSteps | Number of steps the learning rate schedule takes before it is complete, specified as a positive integer or Inf. For learning rate schedules that continue indefinitely (also known as infinite learning rate schedules), this property is Inf. |
A time-based decay learning rate schedule requires one additional property: the decay value. Declare the decay value in the properties block.
properties
% Schedule properties
Decay
end
Create Constructor Function
Create the function that constructs the schedule and initializes the schedule properties. Specify any variables required to create the schedule as inputs to the constructor function.
The time-based decay learning rate schedule constructor function requires one argument (the decay). Specify one input argument named decay in the timeBasedDecayLearnRate function that corresponds to the decay. Add a comment to the top of the function that explains the syntax of the function.
function schedule = timeBasedDecayLearnRate(decay)
    % timeBasedDecayLearnRate Time-based decay learning rate
    % schedule
    %   schedule = timeBasedDecayLearnRate(decay) creates a
    %   time-based decay learning rate schedule with the specified
    %   decay.

    ...
end
Initialize Schedule Properties
Initialize the schedule properties in the constructor function. Replace the comment % Define schedule constructor function here with code that initializes the schedule properties.
Because the time-based decay learning rate schedule updates the learning rate each iteration, set the FrequencyUnit property to "iteration".
Because the time-based decay learning rate schedule is infinite, set the NumSteps property to Inf.
Set the schedule Decay property to the decay argument.
% Set schedule properties.
schedule.FrequencyUnit = "iteration";
schedule.NumSteps = Inf;
schedule.Decay = decay;
View the completed constructor function.
function schedule = timeBasedDecayLearnRate(decay)
% timeBasedDecayLearnRate Time-based decay learning rate
% schedule
% schedule = timeBasedDecayLearnRate(decay) creates a
% time-based decay learning rate schedule with the specified
% decay.
% Set schedule properties.
schedule.FrequencyUnit = "iteration";
schedule.NumSteps = Inf;
schedule.Decay = decay;
end
With this constructor function, the command timeBasedDecayLearnRate(0.01) creates a time-based decay learning rate schedule with a decay value of 0.01.
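To get a feel for how the decay value shapes the schedule, one option is to solve the decay rule used in this example, learnRate = initialLearnRate/(1 + decay*(iteration-1)), for the iteration at which the learning rate falls to half of its initial value. Setting 1 + decay*(iteration-1) equal to 2 gives iteration = 1 + 1/decay. The sketch below evaluates this for a few decay values; the values themselves are assumptions chosen for illustration, not recommendations.

```matlab
% Illustrative decay values (assumptions, not recommendations).
decay = [0.001 0.01 0.1];

% Setting 1 + decay*(iteration-1) = 2 gives the iteration at which the
% learning rate is half of the initial learning rate.
halfIteration = 1 + 1./decay;

% Confirm by evaluating the decay rule at those iterations.
initialLearnRate = 0.01;
learnRate = initialLearnRate ./ (1 + decay.*(halfIteration-1));

% Columns: decay, halving iteration, learning rate at that iteration.
disp([decay; halfIteration; learnRate]')
```

With a decay of 0.01, as in this example, the learning rate reaches half of its initial value at iteration 101. Larger decay values shorten that horizon proportionally.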
Create Update Function
Create the function that updates the learning rate.
Create a function named update that updates the learning rate schedule properties and also returns the calculated learning rate value.
The update function has the syntax [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch), where:
schedule is an instance of the learning rate schedule.
learnRate is the calculated learning rate value.
initialLearnRate is the initial learning rate.
iteration is the iteration number.
epoch is the epoch number.
The time-based decay learning rate schedule uses this formula to calculate the learning rate:
$$\alpha_k = \frac{\alpha_0}{1 + \lambda (k - 1)}$$
where:
k is the iteration number.
α0 is the base learning rate, specified by the InitialLearnRate option of the trainingOptions function.
λ is the decay value of the schedule.
Implement this operation in update. The schedule does not require updating any state values, so the output schedule is unchanged.
Because a time-based decay learning rate schedule does not require the epoch number, the syntax for update for the schedule is [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~). Because the time-based decay learning rate schedule is not finite, there is no need to update the IsDone property.
function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
% UPDATE Update learning rate schedule
% [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
% calculates the learning rate for the specified iteration
% and also returns the updated schedule object.
% Calculate learning rate.
decay = schedule.Decay;
learnRate = initialLearnRate / (1 + decay*(iteration-1));
end
Completed Learning Rate Schedule
View the completed learning rate schedule class file.
classdef timeBasedDecayLearnRate < deep.LearnRateSchedule
% timeBasedDecayLearnRate Time-based decay learning rate schedule
properties
% Schedule properties
Decay
end
methods
function schedule = timeBasedDecayLearnRate(decay)
% timeBasedDecayLearnRate Time-based decay learning rate
% schedule
% schedule = timeBasedDecayLearnRate(decay) creates a
% time-based decay learning rate schedule with the specified
% decay.
% Set schedule properties.
schedule.FrequencyUnit = "iteration";
schedule.NumSteps = Inf;
schedule.Decay = decay;
end
function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
% UPDATE Update learning rate schedule
% [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
% calculates the learning rate for the specified iteration
% and also returns the updated schedule object.
% Calculate learning rate.
decay = schedule.Decay;
learnRate = initialLearnRate / (1 + decay*(iteration-1));
end
end
end
Train Using Custom Learning Rate Schedule Object
You can use a custom learning rate schedule object in the same way as any other learning rate schedule object in the trainingOptions function. This example shows how to create and train a network for digit classification using the time-based decay learning rate schedule object that you defined earlier.
Load the example training data.
load DigitsDataTrain
Create a layer array.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];
Create an instance of a time-based decay learning rate schedule object with a decay value of 0.01.
schedule = timeBasedDecayLearnRate(0.01)
schedule = 
  timeBasedDecayLearnRate with properties:

            Decay: 0.0100
    FrequencyUnit: "iteration"
         NumSteps: Inf
Specify the training options. To train using the learning rate schedule object, set the LearnRateSchedule training option to the object.
options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    LearnRateSchedule=schedule, ...
    Metrics="accuracy");
Train the neural network using the trainnet function. For classification, use index cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.
To get information about training, such as the learning rate value for each iteration, use the second output of the trainnet function.
[net,info] = trainnet(XTrain,labelsTrain,layers,"index-crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    TrainingAccuracy
    _________    _____    ___________    _________    ____________    ________________
            1        1       00:00:02         0.01          2.5434              10.938
           50        2       00:00:04    0.0067114         0.36849              89.062
          100        3       00:00:06    0.0050251         0.18073               93.75
          150        4       00:00:08    0.0040161        0.093009              97.656
          200        6       00:00:10    0.0033445        0.079959              99.219
          250        7       00:00:12    0.0028653        0.062385              99.219
          300        8       00:00:14    0.0025063        0.033808                 100
          350        9       00:00:16    0.0022272          0.0498                 100
          390       10       00:00:17     0.002045        0.047401                 100
Training stopped: Max epochs completed
Extract the learning rate information from the training information and visualize it in a plot.
figure
plot(info.TrainingHistory.LearnRate)
ylim([0 inf])
xlabel("Iteration")
ylabel("Learning Rate")
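One way to confirm that the custom schedule behaved as intended is to compare the logged learning rates with the closed-form decay rule. The sketch below is a minimal check under two assumptions: that info.TrainingHistory.LearnRate holds one value per iteration, and that the initial learning rate is 0.01, the value the training output reports at iteration 1. The variable names rule, finalLearnRate, and maxDifference are introduced here for illustration. The comparison runs only if info exists in the workspace.

```matlab
% Settings used in this example.
initialLearnRate = 0.01;   % value reported at iteration 1 of the training output
decay = 0.01;              % decay passed to timeBasedDecayLearnRate

% Closed-form decay rule, matching the update function of the schedule.
rule = @(iteration) initialLearnRate ./ (1 + decay*(iteration-1));

% Learning rate at the last iteration shown in the training output.
finalLearnRate = rule(390)

% If the training information is available, compare it with the rule.
% Assumes one logged learning rate per iteration.
if exist("info","var")
    logged = double(info.TrainingHistory.LearnRate(:));
    iteration = (1:numel(logged))';
    maxDifference = max(abs(logged - rule(iteration)))
end
```

A maxDifference close to zero suggests that the schedule applied the decay rule at every iteration. A larger value would point to a mismatch in the initial learning rate, the decay, or the assumption about how the history is logged.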
Test the neural network using the testnet function. For single-label classification, evaluate the accuracy. The accuracy is the percentage of correct predictions. By default, the testnet function uses a GPU if one is available. To select the execution environment manually, use the ExecutionEnvironment argument of the testnet function.
load DigitsDataTest
classNames = categories(labelsTest);
accuracy = testnet(net,XTest,labelsTest,"accuracy")
accuracy = 97.7000
See Also
trainingOptions | trainnet | dlnetwork | piecewiseLearnRate | warmupLearnRate | polynomialLearnRate | exponentialLearnRate | cosineLearnRate | cyclicalLearnRate