How do I debug a convolutional neural network with a custom training loop that is not learning?
Hello! I have been trying to design a CNN for image analysis. The CNN is trained on simulated images of size 132 x 132 x 6 (spatial, spatial, channel). The simulated images are computed using a bi-exponential equation of the form $S(b) = S_0\left(f e^{-b D_f} + (1-f) e^{-b D_s}\right)$. In the CNN, the input images are forward passed through the network to generate four feature maps ($f$, $D_f$, $S_0$, and $D_s$), which are then scaled and used to calculate the predicted image signals, $S_{pred}$. The predicted image signals are then compared to the input image signals $S$ using the mean squared error loss function, and the resulting gradients are used to update the network. The problem is that the network is not learning. After some inspection I noticed that the gradients are all going to zero, but I'm not sure how to fix this. I have tried changing the learning rate, switching between the Adam and SGDM optimizers, and changing the mini-batch size, but I encounter the same problem. Any advice or feedback is greatly appreciated!
Also, I have removed parts of the code to make it as simple as possible for the time being, but will add in validation and testing loops.
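For reference, the signal model and the output scaling used in the loss function further down can be written out explicitly. This is only a transcription of the posted code; the symbols $\sigma_1,\dots,\sigma_4$ for the four sigmoid output channels are notation introduced here for illustration.

```latex
\begin{aligned}
f &= 0.5\,\sigma_1, \quad D_f = 0.107\,\sigma_2, \quad S_0 = 0.6\,\sigma_3 + 0.7, \quad D_s = 0.0017\,\sigma_4,\\
S_{\mathrm{pred}}(b) &= S_0\left(f\,e^{-b D_f} + (1-f)\,e^{-b D_s}\right), \qquad b \in \{50, 100, 150, 250, 500, 800\}.
\end{aligned}
```

Since each $\sigma_k$ lies in $(0,1)$, the scaled maps fall in $f \in (0, 0.5)$, $D_f \in (0, 0.107)$, $S_0 \in (0.7, 1.3)$ and $D_s \in (0, 0.0017)$. The simulated images use $S_0 = 1$ because the tissue mask is all ones.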
% Image Parameters
rng(1);
imageSize = [132, 132];
bValue = [50 100 150 250 500 800]; % non-zero diffusion weightings
numbVal = length(bValue);
minDf = 0.0017;
maxDf = 0.107;
minf = 0.1;
maxf = 0.5;
minDs = 0.0003;
maxDs = 0.0017;
DfSim = minDf + (maxDf-minDf).*rand(10,1);
fSim = minf + (maxf-minf).*rand(10,1);
DsSim = minDs + (maxDs-minDs).*rand(10,1);
numIm = length(DfSim) * length(fSim) * length(DsSim); % number of 132 x 132 x 6 images
tissue = ones(imageSize);
bValue = reshape(bValue, [1,1,numbVal]); % Reshape bValue for matrix operation
% Prepare a directory to store the simulated images
outputDir = fullfile(tempdir, 'SimulatedDW-MRI');
if ~exist(outputDir, 'dir')
    mkdir(outputDir);
end
% Initialize a table to store the image file paths and parameters
fprintf('Total simulated images: %d\n', numIm);
imageData = table('Size', [0 4],...
'VariableTypes', {'cell', 'double', 'double', 'double'},...
'VariableNames', {'imageFilePath', 'DfSim', 'fSim', 'DsSim'});
% Start the timer
tic;
% Loop through each combination of DfSim, fSim, and DsSim
imageIdx = 0;
S = zeros([imageSize length(bValue) numIm]);
for DfIdx = 1:length(DfSim)
    for fIdx = 1:length(fSim)
        for DsIdx = 1:length(DsSim)
            imageIdx = imageIdx + 1;
            % Calculate the diffusion signal for each b value for each channel
            S(:,:,:,imageIdx) = tissue .* ((fSim(fIdx) .* exp(-bValue .* DfSim(DfIdx))) + ((1-fSim(fIdx)) .* exp(-bValue .* DsSim(DsIdx))));
            % Track progress
            fprintf('Processing image %d out of %d\n', imageIdx, numIm);
        end
    end
end
for imageIdx = 1:numIm
    fileName = sprintf('%s/image%d.mat', outputDir, imageIdx); % Write the image to a .mat file
    S_single = S(:,:,:,imageIdx);
    save(fileName, 'S_single');
    DfIdx = ceil(imageIdx / (length(fSim)*length(DsSim)));
    fIdx = ceil((imageIdx - (DfIdx-1)*length(fSim)*length(DsSim)) / length(DsSim));
    DsIdx = imageIdx - (DfIdx-1)*length(fSim)*length(DsSim) - (fIdx-1)*length(DsSim);
    imageData(imageIdx, :) = {fileName, DfSim(DfIdx), fSim(fIdx), DsSim(DsIdx)};
    fprintf('Saving image %d out of %d\n', imageIdx, numIm);
end
elapsedTime = toc;
fprintf('Computation time: %.2f seconds\n', elapsedTime);
%% Split data in training, validation, and testing sets
trainSplit = 0.8;
valSplit = 0.1;
testSplit = 0.1;
n = height(imageData);
idx = randperm(n);
trainIdx = idx(1:round(trainSplit*n));
valIdx = idx(round(trainSplit*n)+1:round((trainSplit+valSplit)*n));
testIdx = idx(round((trainSplit+valSplit)*n)+1:end);
imageDataTrain = imageData(trainIdx, :);
imageDataVal = imageData(valIdx, :);
imageDataTest = imageData(testIdx, :);
trainImds = fileDatastore(imageDataTrain.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
trainLabelsDatastore = arrayDatastore(imageDataTrain{:, {'DfSim', 'fSim', 'DsSim'}});
trainCombinedDatastore = combine(trainImds, trainLabelsDatastore);
valImds = fileDatastore(imageDataVal.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
valLabelsDatastore = arrayDatastore(imageDataVal{:, {'DfSim', 'fSim', 'DsSim'}});
valCombinedDatastore = combine(valImds, valLabelsDatastore);
testImds = fileDatastore(imageDataTest.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
testLabelsDatastore = arrayDatastore(imageDataTest{:, {'DfSim', 'fSim', 'DsSim'}});
testCombinedDatastore = combine(testImds, testLabelsDatastore);
%% Define the network layers
lgraph = layerGraph();
Layers = [
imageInputLayer([132 132 6],"Name","imageinput","Normalization","none")
convolution2dLayer([1 1],32,"Name","conv_1","Padding","same")
batchNormalizationLayer("Name","batchnorm_1")
leakyReluLayer("Name","relu_1")
dropoutLayer(0.02,"Name","dropout_1")
convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
leakyReluLayer("Name","relu_2")
dropoutLayer(0.02,"Name","dropout_2")
convolution2dLayer([1 1],64,"Name","conv_3","Padding","same")
batchNormalizationLayer("Name","batchnorm_2")
leakyReluLayer("Name","relu_3")
dropoutLayer(0.02,"Name","dropout_3")
convolution2dLayer([3 3],64,"Name","conv_4","Padding","same")
leakyReluLayer("Name","relu_4")
dropoutLayer(0.02,"Name","dropout_4")
convolution2dLayer([1 1],128,"Name","conv_5","Padding","same")
batchNormalizationLayer("Name","batchnorm_3")
leakyReluLayer("Name","relu_5")
dropoutLayer(0.02,"Name","dropout_5")
convolution2dLayer([3 3],128,"Name","conv_6","Padding","same")
leakyReluLayer("Name","relu_6")
dropoutLayer(0.02,"Name","dropout_6")
convolution2dLayer([1 1],64,"Name","conv_7","Padding","same")
batchNormalizationLayer("Name","batchnorm_4")
leakyReluLayer("Name","relu_7")
dropoutLayer(0.02,"Name","dropout_7")
convolution2dLayer([3 3],64,"Name","conv_8","Padding","same")
leakyReluLayer("Name","relu_8")
dropoutLayer(0.02,"Name","dropout_8")
convolution2dLayer([1 1],32,"Name","conv_9","Padding","same")
batchNormalizationLayer("Name","batchnorm_5")
leakyReluLayer("Name","relu_9")
dropoutLayer(0.02,"Name","dropout_9")
convolution2dLayer([3 3],32,"Name","conv_10","Padding","same")
leakyReluLayer("Name","relu_10")
dropoutLayer(0.02,"Name","dropout_10")
convolution2dLayer([1 1],4,"Name","conv_11","Padding","same")
sigmoidLayer("Name","sigmoid")];
lgraph = addLayers(lgraph,Layers);
dlnet = dlnetwork(lgraph);
plot(lgraph);
%% Training loop
numEpochs = 200;
miniBatchSize = 10;
initialLearnRate = 0.01;
decay = 0.00001;
gradDecay = 0.9;
sqGradDecay = 0.999;
mbq = minibatchqueue(trainCombinedDatastore,...
'MiniBatchSize', miniBatchSize,...
'MiniBatchFormat', {'SSCB', 'CB'}, ...
'OutputAsDlarray', [1, 1],...
'OutputEnvironment', 'auto');
averageGrad = [];
averageSqGrad = [];
numObservationsTrain = height(imageDataTrain); % number of training images (not the full simulated set)
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
plots = 'training-progress';
if strcmp(plots, 'training-progress')
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end
epoch = 0;
iteration = 0;
start = tic;
% Loop over epochs.
while epoch < numEpochs
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        % Read mini-batch of data.
        [dlX, dlT] = next(mbq);
        [loss, gradients, state] = dlfeval(@modelLoss,dlnet,dlX);
        dlnet.State = state;
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        % Update network parameters
        [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients,averageGrad,averageSqGrad,...
            iteration, learnRate, gradDecay, sqGradDecay);
        % Extract weights of first convolution layer
        conv1Weights = dlnet.Layers(2).Weights;
        % Print or save the weights
        disp('Weights of conv_1 layer:');
        disp(conv1Weights);
        if strcmp(plots, 'training-progress')
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain, iteration, double(gather(extractdata(loss))));
            title("Epoch: " + epoch + " , Elapsed: " + string(D));
            drawnow
        end
    end
end
%% Custom loss function
function [loss, gradients,state] = modelLoss(dlnet, dlX)
    % Forward data through network.
    [dlY, state] = forward(dlnet, dlX);
    % Calculate parameter maps
    fMap = dlY(:,:,1,:).*0.5;
    DfMap = dlY(:,:,2,:).*0.107;
    S0Map = (dlY(:,:,3,:).*0.6) + 0.7;
    DsMap = dlY(:,:,4,:).*0.0017;
    % diffusion weightings
    dlB = [50 100 150 250 500 800];
    % Use model outputs to predict the diffusion signal for each image
    % in mini batch
    Spred = zeros(size(dlX));
    for b = 1:length(dlB)
        Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
    end
    % Convert Spred to dlarray
    Spred = dlarray(Spred, 'SSCB');
    % Calculate the mse loss
    loss = mse(Spred, dlX);
    % Calculate gradients of loss with respect to learnable parameters.
    gradients = dlgradient(loss, dlnet.Learnables);
end
Accepted Answer
Richard
on 26 Jun 2023
These lines of code:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
    Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
create a variable, Spred, that has no traced dependency on the output of the network. This means that your mse() call in fact only traces a dependency on the original input dlX, so the gradients of the loss with respect to the learnables are all zero.
Try this instead to create an Spred that incorporates the dependency on the network output:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX), 'like', dlX);
for b = 1:length(dlB)
    Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
The 'like' syntax for zeros() constructs a zeros dlarray that is traced, like its input, so the indexed assignments within the loop are captured. In your original version, because Spred is created as a plain double array, the indexed assignment into Spred(:,:,b,:) casts the computed and traced right-hand side to plain double values, which loses the trace information that dlgradient depends on.
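As a quick illustration of the difference between the two preallocation styles, the following sketch compares the classes of the resulting arrays. The small random array X is a hypothetical stand-in for a mini-batch. Tracing itself only happens inside dlfeval, so this only shows the class difference, not the trace.

```matlab
% Hypothetical stand-in for a mini-batch (not the data from the question).
X = dlarray(rand(4,4,6,2), "SSCB");

% Plain preallocation: an ordinary double array.
plainZeros = zeros(size(X));

% 'like' preallocation: the result takes its class from X.
likeZeros = zeros(size(X), 'like', X);

disp(class(plainZeros))
disp(class(likeZeros))
```

The first array is a double and the second is a dlarray, which is what allows indexed assignments into it to stay on the trace when the code runs inside dlfeval.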
Incidentally, I think you can also remove the loop entirely by reshaping dlB into a 1-by-1-by-6 array and relying on implicit expansion, which should be faster:
dlB = reshape(dlB, 1,1,[]);
Spred = S0Map .* (fMap.*exp(-dlB.*DfMap) + (1 - fMap).*exp(-dlB.*DsMap));
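Putting these suggestions together, the loss function might look roughly like the sketch below. The scaling constants and b-values are copied from the question. The small stand-in network and random input at the top are assumptions added purely so the snippet runs on its own and the gradients can be checked.

```matlab
% Usage sketch with a small stand-in network (hypothetical, for illustration).
layers = [
    imageInputLayer([16 16 6], "Name", "in", "Normalization", "none")
    convolution2dLayer([1 1], 4, "Name", "conv", "Padding", "same")
    sigmoidLayer("Name", "sigmoid")];
net = dlnetwork(layers);
X = dlarray(rand(16,16,6,2), "SSCB");
[loss, gradients] = dlfeval(@modelLoss, net, X);

% Largest absolute gradient per learnable; all zeros would indicate a lost trace.
maxAbsGrad = cellfun(@(g) max(abs(gather(extractdata(g))), [], "all"), gradients.Value);
disp(maxAbsGrad)

function [loss, gradients, state] = modelLoss(dlnet, dlX)
    % Forward data through the network; dlY is traced inside dlfeval.
    [dlY, state] = forward(dlnet, dlX);

    % Scale the sigmoid outputs to the parameter ranges used in the question.
    fMap  = dlY(:,:,1,:) .* 0.5;
    DfMap = dlY(:,:,2,:) .* 0.107;
    S0Map = dlY(:,:,3,:) .* 0.6 + 0.7;
    DsMap = dlY(:,:,4,:) .* 0.0017;

    % b-values laid out along the channel dimension for implicit expansion.
    dlB = reshape([50 100 150 250 500 800], 1, 1, []);

    % Built only from traced operands, so Spred keeps the trace.
    Spred = S0Map .* (fMap.*exp(-dlB.*DfMap) + (1 - fMap).*exp(-dlB.*DsMap));

    loss = mse(Spred, dlX);
    gradients = dlgradient(loss, dlnet.Learnables);
end
```

If the printed values are non-zero, the loss depends on the learnables again.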
2 Comments
Richard
on 26 Jun 2023
Hi Marissa,
A mini-batch of 10 samples is quite small, and I think this is causing you to see a lot of noise in the gradients. When I increase the mini-batch size to 64, I see a much smoother curve.
More Answers (1)
Aniketh
on 25 Jun 2023
A very probable cause, and one I have run into myself a few times, is initialization: check how your network's weights are initialized. If the weights are initialized too small, it can lead to vanishing gradients. Consider using a suitable initialization method, such as Xavier or He initialization, which helps keep the weights in a reasonable range.
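As a rough sketch of how this could be set in MATLAB, convolution2dLayer accepts a "WeightsInitializer" option. The layer names below are hypothetical. Note that when no initializer is given, as in the layers posted in the question, recent releases default to Glorot (Xavier) initialization.

```matlab
% He initialization requested explicitly (often paired with ReLU-type activations).
convHe = convolution2dLayer([3 3], 32, "Name", "conv_he", "Padding", "same", ...
    "WeightsInitializer", "he");

% No initializer specified: the layer falls back to its default.
convDefault = convolution2dLayer([3 3], 32, "Name", "conv_default", "Padding", "same");

disp(convHe.WeightsInitializer)
disp(convDefault.WeightsInitializer)
```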
Another thing to consider is your network architecture: evaluate its depth and complexity. Very deep networks are more susceptible to vanishing gradients. If your network is too deep, consider reducing the number of layers or introducing skip connections (e.g., residual connections) to facilitate gradient flow.
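A minimal sketch of one skip connection, assuming a much smaller network than the one in the question, could use additionLayer and connectLayers. The layer names and sizes here are illustrative only.

```matlab
% Main branch: listing additionLayer in the array wires relu_res into add/in1.
layers = [
    imageInputLayer([132 132 6], "Name", "input", "Normalization", "none")
    convolution2dLayer([1 1], 32, "Name", "conv_in", "Padding", "same")
    leakyReluLayer("Name", "relu_in")
    convolution2dLayer([3 3], 32, "Name", "conv_res", "Padding", "same")
    leakyReluLayer("Name", "relu_res")
    additionLayer(2, "Name", "add")
    convolution2dLayer([1 1], 4, "Name", "conv_out", "Padding", "same")
    sigmoidLayer("Name", "sigmoid")];
lgraph = layerGraph(layers);

% Skip branch: route the output of relu_in around the 3x3 convolution.
lgraph = connectLayers(lgraph, "relu_in", "add/in2");

dlnet = dlnetwork(lgraph);
```

Both inputs to the addition layer have 32 channels and the same spatial size, which the element-wise sum requires.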