Loss stops calculating with custom layer (weighted cross entropy function)

Tony on 22 Oct 2019
Edited: Raza Ali on 26 Oct 2019
Hi all, I am relatively new to deep learning and have been trying to train existing networks to distinguish between images classified as "0" and "1". I've mostly been training AlexNet, and I have had a reasonable amount of success. My data is imbalanced, so I am replacing the last layer with a custom weighted loss layer.
I followed the documentation example at this link and was able to replace the last layer with a weighted loss layer (https://www.mathworks.com/help/deeplearning/ug/create-custom-weighted-cross-entropy-classification-layer.html). Here is the code I ended up using:
classdef weightedClassificationLayer < nnet.layer.ClassificationLayer

    properties
        % Row vector of weights corresponding to the classes in the
        % training data.
        ClassWeights
    end

    methods
        function layer = weightedClassificationLayer(classWeights, name)
            % layer = weightedClassificationLayer(classWeights) creates a
            % weighted cross entropy loss layer. classWeights is a row
            % vector of weights corresponding to the classes in the order
            % that they appear in the training data.
            %
            % layer = weightedClassificationLayer(classWeights, name)
            % additionally specifies the layer name.

            % Set class weights.
            layer.ClassWeights = classWeights;

            % Set layer name.
            if nargin == 2
                layer.Name = name;
            end

            % Set layer description.
            layer.Description = 'Weighted cross entropy';
        end

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the weighted cross
            % entropy loss between the predictions Y and the training
            % targets T.
            N = size(Y,4);
            Y = squeeze(Y);
            T = squeeze(T);
            W = layer.ClassWeights;
            loss = -sum(W*(T.*log(Y)))/N;
        end

        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns the derivatives of
            % the weighted cross entropy loss with respect to the
            % predictions Y.
            [~,~,K,N] = size(Y);
            Y = squeeze(Y);
            T = squeeze(T);
            W = layer.ClassWeights;
            dLdY = -(W'.*T./Y)/N;
            dLdY = reshape(dLdY,[1 1 K N]);
        end
    end
end
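For reference, after squeeze, Y and T are K-by-N arrays (K classes, N observations in the mini-batch) and w is the ClassWeights row vector. forwardLoss and backwardLoss above then compute the following (log is the natural logarithm, as in MATLAB):

```latex
L = -\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} w_k \, T_{kn} \log Y_{kn},
\qquad
\frac{\partial L}{\partial Y_{kn}} = -\frac{1}{N} \, \frac{w_k \, T_{kn}}{Y_{kn}}
```

Since T is one-hot, each observation contributes only the log-probability of its true class, scaled by that class's weight.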
The checkLayer function says the layer is valid, but unfortunately when I pass it to trainNetwork, after a few iterations it stops calculating a loss and the network's accuracy doesn't improve.
[Figure: training-progress plot from trainNetwork]
Any ideas why it appears to stop calculating the loss? Is there something wrong with my code? Is it because one of my classes is named "0", which breaks the log function? When I don't use the custom layer, it usually gets up to around 75% accuracy. Any help is appreciated. Thanks! -Tony

Accepted Answer

Kaashyap Pappu on 25 Oct 2019
Whenever the loss value is either “NaN” or “Inf”, it is not plotted on the curve. Also, the accuracy stays near 50%, which indicates the network is essentially guessing the output. The class label “0” does not influence the loss either. This suggests that the network is unable to find a good solution for the classification.
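To make the NaN/Inf point concrete, the standalone check below reuses the forwardLoss and backwardLoss expressions on a tiny hand-made mini-batch. It assumes, without confirmation from the thread, that some softmax outputs can reach exactly 0 (for example through single-precision underflow). The numbers and the clamping floor are illustrative, and clamping is a common safeguard, not part of the MathWorks example.

```matlab
% Hypothetical mini-batch: K = 2 classes, N = 3 observations, already
% squeezed to K-by-N as in forwardLoss. All values are illustrative.
W = [0.3 0.7];                 % class weights (row vector)
T = [1 0 1; 0 1 0];            % one-hot targets: class names never appear here
Y = [0.9 0.2 1; 0.1 0.8 0];    % probabilities; third column is exactly [1; 0]
N = size(Y, 2);

% Same expressions as forwardLoss/backwardLoss.
% T.*log(Y) contains 0 * log(0) = 0 * -Inf = NaN, and T./Y contains 0/0 = NaN.
lossRaw = -sum(W*(T.*log(Y)))/N;
dLdYRaw = -(W'.*T./Y)/N;

% Hypothetical safeguard: clamp probabilities to a small positive floor
% before taking the log, so the zero entry no longer produces NaN.
floorValue = 1e-8;
Yc = max(Y, floorValue);
lossClamped = -sum(W*(T.*log(Yc)))/N;

fprintf('raw loss: %g, clamped loss: %g\n', lossRaw, lossClamped);
```

If the clamped loss stays finite in training while the unclamped one turns into NaN, exact zeros in Y are a plausible culprit.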
The weights passed to weightedClassificationLayer are probably not compensating for the class imbalance.
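One common heuristic for choosing the weights, not taken from the thread, is to make each weight inversely proportional to its class frequency. A minimal sketch with a hypothetical label vector (6 images of class “0”, 2 of class “1”); it assumes the layer expects the weights in the order returned by categories(labels):

```matlab
% Hypothetical labels for an imbalanced two-class problem.
labels = categorical([0 0 0 0 0 0 1 1]);

classNames = categories(labels);     % {'0'; '1'}, the assumed weight order
counts = countcats(labels(:));       % images per class, as a column: [6; 2]
numObs = numel(labels);
numClasses = numel(classNames);

% Inverse-frequency weights: every class contributes the same total weight.
classWeights = numObs ./ (numClasses * counts);   % [0.6667; 2]
classWeights = reshape(classWeights, 1, []);      % the layer expects a row vector
```

The resulting row can then be passed as weightedClassificationLayer(classWeights).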
You could try setting “Verbose” to “true” in “trainingOptions” to get training data such as the loss values, current epoch, and validation loss in tabular form. You can also try reducing the learning rate, or setting the “GradientThreshold” option to limit how much the parameters change within the network.
Using this example, I was able to reach 100% training accuracy with “weightedClassificationLayer” by making the following modification:
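A minimal sketch of those settings, using documented trainingOptions name-value pairs. The numeric values are placeholders to tune, not recommendations from the answer:

```matlab
% Placeholder values; tune for your own data.
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-4, ...     % smaller than the 'sgdm' default of 0.01
    'GradientThreshold', 1, ...       % rescale gradients whose norm exceeds 1
    'MiniBatchSize', 32, ...
    'MaxEpochs', 10, ...
    'Verbose', true, ...              % print a progress table in the Command Window
    'VerboseFrequency', 10, ...       % print a row every 10 iterations
    'Plots', 'training-progress');
```

With the verbose table, the loss values themselves (including any NaN) are printed instead of just disappearing from the plot. The options would then go to trainNetwork(imdsTrain, layers, options), where imdsTrain is a hypothetical training datastore.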
layers = [
    % ... earlier layers from the example ...
    weightedClassificationLayer([0.2 0.2 0.1 0.4 0.1])]; % arbitrarily chosen weights
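For the two-class AlexNet setup described in the question, the custom layer would take the place of the final classification layer. A sketch, assuming the AlexNet support package is installed and weightedClassificationLayer.m is on the path; the layer names and weight values are hypothetical:

```matlab
% Keep AlexNet's feature layers and drop its final fully connected,
% softmax and classification layers (the last three).
net = alexnet;
layersTransfer = net.Layers(1:end-3);

numClasses = 2;                  % classes "0" and "1"
classWeights = [0.6667 2];       % hypothetical weights, e.g. inverse class frequency

layers = [
    layersTransfer
    fullyConnectedLayer(numClasses, 'Name', 'fc_binary')
    softmaxLayer('Name', 'softmax_binary')
    weightedClassificationLayer(classWeights, 'weighted_output')];
```

Note that the number of weights must match numClasses; reusing the five weights above in a two-class network would not work.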
Hope this helps!
Raza Ali on 26 Oct 2019
Hi Tony,
Have you found a solution? I am having a similar issue, except that in my case training does not start. It gives this error:
Error using trainNetwork (line 170)
Error using 'backwardLoss' in Layer weightedClassificationLayer. The function threw an error and could not be executed.
Caused by:
Error using .*
Matrix dimensions must agree.
Please help if you can...
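The `.*` in backwardLoss relies on implicit expansion and only lines up when ClassWeights is a 1-by-K row with exactly one entry per class. The script below is a diagnostic sketch, not a confirmed explanation of the error above: it shows two ways the expression raises a size error (five weights copied from the answer into a two-class problem, or weights supplied as a column) and a check that catches both. The same mistakes would also break the `W*(...)` product in forwardLoss, so the exact message may differ.

```matlab
% Illustrative 2-class mini-batch of 4 observations, already squeezed to
% K-by-N (K = 2, N = 4) as in backwardLoss.
T = [1 0 1 0; 0 1 0 1];
Y = [0.7 0.4 0.9 0.2; 0.3 0.6 0.1 0.8];
N = size(Y, 2);

% Two hypothetical mis-specified weight vectors: five weights for two
% classes, and the right number of weights given as a column.
badWeights = {[0.2 0.2 0.1 0.4 0.1], [0.25; 0.75]};
failed = false(size(badWeights));
for i = 1:numel(badWeights)
    W = badWeights{i};
    try
        tmp = -(W'.*T./Y)/N; %#ok<NASGU>
    catch err
        failed(i) = true;
        fprintf('weights %d: %s\n', i, err.message);
    end
end

% Forcing a 1-by-K row and checking its length avoids both problems.
W = reshape([0.25; 0.75], 1, []);
assert(numel(W) == size(T, 1), 'Expected one class weight per class.');
dLdY = -(W'.*T./Y)/N;   % K-by-1 .* K-by-N expands to K-by-N
```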
