Overfitting in a deep neural network
I am using the ResNet-18 CNN architecture with transfer learning for classification. After training and testing the model, I find that it is overfitting.
Here is my code. Can anyone please tell me what changes I should make in the code below? Please see the attached results file, which shows the overfitting.
clear all
close all
imds = imageDatastore("D:\DatasetJPG", ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized'); % randomly assign 70% of each class to training, 30% to validation
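% Added check (not in the original script): show how many images of each class
% ended up in each split. A strong class imbalance can make the accuracy
% numbers below harder to interpret.
countEachLabel(imdsTrain)
countEachLabel(imdsValidation)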
net = resnet18; % The first time, you have to download the ResNet-18 support package from the Add-On Explorer.
%Replace Final Layers
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(numClasses,'Name','new_fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc1000' ,newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassLayer);
%Train Network
inputSize = net.Layers(1).InputSize;
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-5,5], ...
    'RandXTranslation',[-5 5], ...
    'RandYTranslation',[-5 5]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-3, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',5, ...
    'Verbose',false, ...
    'Plots','training-progress');
trainedNet = trainNetwork(augimdsTrain,lgraph,options);
YPred = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
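% Added for comparison (not in the original script; variable names are illustrative):
% accuracy on the training images without augmentation. A training accuracy much
% higher than the validation accuracy above is the usual sign of overfitting.
augimdsTrainNoAug = augmentedImageDatastore(inputSize(1:2),imdsTrain);
YPredTrain = classify(trainedNet,augimdsTrainNoAug);
trainAccuracy = mean(YPredTrain == imdsTrain.Labels)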
C = confusionmat(imdsValidation.Labels,YPred)
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
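A minimal sketch of changes commonly tried against overfitting in this kind of transfer-learning setup, assuming the variables `imdsTrain`, `imdsValidation`, `inputSize` and `lgraph` from the script above: stronger data augmentation, a larger `L2Regularization` (weight decay) value, and early stopping with `ValidationPatience`. The specific values are illustrative guesses rather than tuned settings, `miniBatchSize` and `valFrequency` are hypothetical helper variables, and left-right flips (`RandXReflection`) only make sense if flipping an image does not change its class.

```matlab
% Sketch only: replaces the augmentation and training section of the script above.
% Assumes imdsTrain, imdsValidation, inputSize and lgraph already exist.
% All numeric values are illustrative starting points, not tuned settings.
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...      % random left-right flips (only if class is flip-invariant)
    'RandRotation',[-10 10], ...     % wider rotation range than the original [-5 5]
    'RandScale',[0.9 1.1], ...       % random zoom in/out by up to 10%
    'RandXTranslation',[-10 10], ... % translations in pixels
    'RandYTranslation',[-10 10]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

miniBatchSize = 10;
% ValidationFrequency is counted in iterations; this validates roughly once per epoch.
valFrequency = max(1,floor(numel(imdsTrain.Files)/miniBatchSize));

options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-3, ...
    'L2Regularization',1e-3, ...     % weight decay; the default is 1e-4
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'ValidationPatience',3, ...      % stop once validation loss fails to improve on 3 checks
    'Verbose',false, ...
    'Plots','training-progress');
trainedNet = trainNetwork(augimdsTrain,lgraph,options);
```

Validating about once per epoch keeps `ValidationPatience` meaningful: with the original `'ValidationFrequency',5`, a patience of 3 would stop training after only 15 iterations without improvement.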