
How do I perform transfer learning for YOLO v3 in MATLAB R2021a?

I am trying to perform transfer learning for YOLO v3.
However, the transfer learning example trains the network with the following code:
pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options);
Here, the input network should be a complete network with a single output layer, which is not the case for YOLO v3.
YOLO v3 has multiple outputs, and it is not a classification problem with a fully connected layer; the last learnable layers are several 1-by-1 convolutional layers.
I tried to replace the two 2-D convolutional layers with new layers and use this new network as the base network for training on a new dataset, but it seems to train all the layers again rather than only the last two layers.
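For reference when replacing these head layers: in YOLO v3, each detection head's final 1-by-1 convolution typically predicts, for every anchor box assigned to that head, four box terms, one objectness score, and one score per class. A replacement layer therefore usually needs numAnchorsPerHead*(5 + numClasses) filters rather than numClasses. A minimal sketch, where the class names and anchor boxes are hypothetical values (three anchors per head):

```matlab
% Hypothetical example values: one class and three anchor boxes per detection head.
classNames = {'vehicle'};
anchorBoxes = {[150 127; 97 90; 68 67]; [38 42; 41 29; 31 23]};

numClasses = numel(classNames);
% Per anchor: 4 box terms + 1 objectness score + 1 score per class.
numPredictorsPerAnchor = 5 + numClasses;

% One output channel per predictor for each anchor assigned to the head.
numFilters1 = size(anchorBoxes{1}, 1) * numPredictorsPerAnchor;
numFilters2 = size(anchorBoxes{2}, 1) * numPredictorsPerAnchor;

newLearnableLayer1 = convolution2dLayer(1, numFilters1, ...
    'Name', 'new_conv1', ...
    'WeightLearnRateFactor', 10, ...
    'BiasLearnRateFactor', 10);
newLearnableLayer2 = convolution2dLayer(1, numFilters2, ...
    'Name', 'new_conv2', ...
    'WeightLearnRateFactor', 10, ...
    'BiasLearnRateFactor', 10);
```

With one class and three anchors per head, each replacement layer has 3*(5+1) = 18 filters.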
% Pretrained network with the same architecture as in the YOLO v3 example
load('yolov3_car.mat');
inputSize = net.Layers(1).InputSize;
if isa(net,'SeriesNetwork')
    lgraph = layerGraph(net.Layers);
else
    lgraph = layerGraph(net);
end
classNames = trainingDataTbl.Properties.VariableNames(2:end);
numClasses = size(classNames, 2);
newLearnableLayer1 = convolution2dLayer(1,numClasses, ...
    'Name','new_conv1', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
newLearnableLayer2 = convolution2dLayer(1,numClasses, ...
    'Name','new_conv2', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'conv2Detection1',newLearnableLayer1);
lgraph = replaceLayer(lgraph,'conv2Detection2',newLearnableLayer2);
figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])
layers = lgraph.Layers;
connections = lgraph.Connections;
% Freeze the weights of the first 33 layers (freezeWeights and
% createLgraphUsingConnections are supporting functions from the transfer learning example)
layers(1:33) = freezeWeights(layers(1:33));
lgraph = createLgraphUsingConnections(layers,connections);
analyzeNetwork(lgraph)
net1 = dlnetwork(lgraph);
yolov3Detector = yolov3ObjectDetector(net1, classNames, anchorBoxes, 'DetectionNetworkSource', {'fire9-concat', 'fire5-concat'});
preprocessedTrainingData = transform(augmentedTrainingData, @(data)preprocess(yolov3Detector, data));
numEpochs = 6;
miniBatchSize = 8;
learningRate = 0.001;
warmupPeriod = 1000;
l2Regularization = 0.0005;
penaltyThreshold = 0.5;
velocity = [];
if canUseParallelPool
    dispatchInBackground = true;
else
    dispatchInBackground = false;
end
mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
    "MiniBatchSize", miniBatchSize,...
    "MiniBatchFcn", @(images, boxes, labels) createBatchData(images, boxes, labels, classNames), ...
    "MiniBatchFormat", ["SSCB", ""],...
    "DispatchInBackground", dispatchInBackground,...
    "OutputCast", ["", "double"]);
if doTraining
    % Create subplots for the learning rate and mini-batch loss.
    fig = figure;
    [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(fig);
    iteration = 0;
    % Custom training loop.
    for epoch = 1:numEpochs
        reset(mbqTrain);
        shuffle(mbqTrain);
        while(hasdata(mbqTrain))
            iteration = iteration + 1;
            [XTrain, YTrain] = next(mbqTrain);
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, state, lossInfo] = dlfeval(@modelGradients, yolov3Detector, XTrain, YTrain, penaltyThreshold);
            % Apply L2 regularization.
            gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, yolov3Detector.Learnables);
            % Determine the current learning rate value.
            currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
            % Update the detector learnable parameters using the SGDM optimizer.
            [yolov3Detector.Learnables, velocity] = sgdmupdate(yolov3Detector.Learnables, gradients, velocity, currentLR);
            % Update the state parameters of dlnetwork.
            yolov3Detector.State = state;
            % Display progress.
            displayLossInfo(epoch, iteration, currentLR, lossInfo);
            % Update training plot with new points.
            updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
        end
    end
else
    yolov3Detector = preTrainedDetector;
end
So what should I do if I want to freeze the other layers and train only the last learnable layers of YOLO v3?

Answers (1)

Davide Fantin on 26 May 2021
In order to "freeze" layers, i.e. not update the learnable parameters, you need to set the learn rate factor for those layers to 0 (before the training loop).
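Why a factor of 0 freezes a parameter can be seen from the SGDM update. A sketch, assuming the learn rate factor f scales the global learning rate α for that parameter, with momentum γ, gradient g_t (including any L2 term), and velocity starting at zero:

```latex
\begin{aligned}
v_t &= \gamma\, v_{t-1} - f\,\alpha\, g_t, \qquad v_0 = 0, \\
\theta_t &= \theta_{t-1} + v_t, \\
f = 0 \;&\Rightarrow\; v_t = \gamma^{t} v_0 = 0 \;\Rightarrow\; \theta_t = \theta_0 \quad \text{for all } t.
\end{aligned}
```

Because the L2 term is part of g_t, it is multiplied by f as well, so it cannot move a parameter whose factor is 0.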
You can do this with the setLearnRateFactor function, which is available for dlnetwork objects. setLearnRateFactor accepts either an individual layer or an entire dlnetwork as input. You can find examples of its usage on the documentation page: https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.layer.setlearnratefactor.html
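A minimal sketch of this approach on a small toy dlnetwork. The layer names 'backbone_conv' and 'new_conv1' are hypothetical stand-ins for the YOLO v3 backbone and a new detection head. The loop sets the learn rate factor of every learnable parameter outside the listed layers to 0, then reads two factors back with getLearnRateFactor:

```matlab
% Toy stand-in for the YOLO v3 dlnetwork; layer names are hypothetical.
layers = [
    imageInputLayer([8 8 3], 'Normalization', 'none', 'Name', 'input')
    convolution2dLayer(3, 4, 'Padding', 'same', 'Name', 'backbone_conv')
    reluLayer('Name', 'relu')
    convolution2dLayer(1, 6, 'Name', 'new_conv1')];
net = dlnetwork(layerGraph(layers));

% Layers that should keep training (assumption: only the new heads).
trainableLayers = "new_conv1";

% Set the learn rate factor of every other learnable parameter to 0.
learnables = net.Learnables;
for i = 1:height(learnables)
    layerName = string(learnables.Layer(i));
    paramName = string(learnables.Parameter(i));
    if ~ismember(layerName, trainableLayers)
        net = setLearnRateFactor(net, layerName, paramName, 0);
    end
end

% Read back one frozen and one trainable factor.
frozenFactor = getLearnRateFactor(net, 'backbone_conv', 'Weights');
trainableFactor = getLearnRateFactor(net, 'new_conv1', 'Weights');
disp([frozenFactor trainableFactor])
```

Applied to the question, the loop would run on net1 right after net1 = dlnetwork(lgraph), with trainableLayers = ["new_conv1" "new_conv2"]. Selecting layers by name avoids depending on layer order; a fixed range such as layers(1:33) may leave some backbone layers trainable, depending on the network. Two further points are worth checking in the R2021a documentation rather than taking from this sketch: whether yolov3ObjectDetector called with 'DetectionNetworkSource' treats net1 as a base network and attaches new, trainable heads (the syntax yolov3ObjectDetector(net, classes, aboxes) is intended for an already complete YOLO v3 network), and whether sgdmupdate applies learn rate factors when it receives a Learnables table, as in the question's loop, rather than a dlnetwork.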
Hope this helps!
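If it is unclear whether the learn rate factors are honored when the parameters are passed as a Learnables table (as yolov3Detector.Learnables is in the question's loop), an alternative that does not depend on that behavior is to zero the gradients of the frozen parameters inside the loop. The sketch below does this on a toy network; the layer names, sizes, and placeholder loss are assumptions, not the YOLO v3 setup. Save it as a script file, because it ends with a local function:

```matlab
% Toy stand-in for the YOLO v3 dlnetwork; layer names are hypothetical.
layers = [
    imageInputLayer([8 8 3], 'Normalization', 'none', 'Name', 'input')
    convolution2dLayer(3, 4, 'Padding', 'same', 'Name', 'backbone_conv')
    reluLayer('Name', 'relu')
    convolution2dLayer(1, 6, 'Name', 'new_conv1')];
net = dlnetwork(layerGraph(layers));

trainableLayers = "new_conv1";   % hypothetical: layers that should keep training
l2Regularization = 0.0005;
learningRate = 0.001;
velocity = [];

% One training step on a random mini-batch.
X = dlarray(rand(8, 8, 3, 2, 'single'), 'SSCB');
gradients = dlfeval(@toyGradients, net, X);

% Apply L2 regularization (as in the question's loop) ...
gradients = dlupdate(@(g, w) g + l2Regularization*w, gradients, net.Learnables);

% ... then zero the gradients of every frozen parameter, so that the
% L2 term cannot move them either.
isFrozen = ~ismember(string(gradients.Layer), trainableLayers);
gradients.Value(isFrozen) = cellfun(@(g) zeros(size(g), 'like', g), ...
    gradients.Value(isFrozen), 'UniformOutput', false);

before = net.Learnables;
[net, velocity] = sgdmupdate(net, gradients, velocity, learningRate);
after = net.Learnables;

% Frozen parameters are unchanged by the update.
for i = find(isFrozen)'
    assert(isequal(extractdata(before.Value{i}), extractdata(after.Value{i})));
end

function gradients = toyGradients(net, X)
    % Placeholder loss for illustration only; not the YOLO v3 loss.
    Y = forward(net, X);
    loss = sum(Y, 'all');
    gradients = dlgradient(loss, net.Learnables);
end
```

In the question's loop, the two masking lines would go right after the dlupdate L2 step and before sgdmupdate, with trainableLayers set to the names of the new head layers. Zeroing after the L2 step matters: zeroing first would leave g + l2Regularization*w nonzero, so weight decay would still change the frozen weights. If the backbone contains batch normalization layers, their running statistics live in State rather than Learnables, so neither this approach nor a zero learn rate factor keeps them fixed.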
