MATLAB Answers

How to add a distance transform map to the loss function in a classification layer

7 views (last 30 days)
Raza Ali
Raza Ali on 27 Aug 2020
Commented: Raza Ali on 10 Sep 2020
Hi everyone, I am trying to incorporate distance-map information into the loss function. I am doing this in the classification layer of a CNN.
However, when I calculate the distance map using the "bwdist(Y)" command, MATLAB produces the following error during training:
"Error using 'forwardLoss' in Layer ClassificationLayer. The function threw an error and could not be executed".
"Expected input image to be a 2-D real-valued, non-sparse gpuArray with underlying class uint8, uint16, uint32, int8, int16, int32, logical, single or double".
How can I add a distance transform map to the loss function, or how can I resolve this error?

  3 Comments

Raza Ali
Raza Ali on 10 Sep 2020
The classification layer code:
%%%%%%%%%%%%%%%%%%
classdef CEDLossLayer < nnet.layer.ClassificationLayer
    % CEDLossLayer  Pixel-weighted binary cross-entropy output layer.
    %   For predictions Y and targets T the loss is
    %     L = -(1/N) * sum( Beta*W.*T.*log(Y) + (1-Beta)*(1-W).*(1-T).*log(1-Y) )
    %   with N = H*W*batch, and W the per-pixel map returned by WeightMap for
    %   channel 1 of each observation's targets.
    %   Assumes Y and T are H-by-W-by-K-by-batch (segmentation-style
    %   outputs) -- TODO confirm for the network this layer is used in.
    %   NOTE(review): WeightMap can return values > 1 (it adds a border term
    %   on top of class weights <= 1), which makes (1-W) negative there;
    %   confirm that this is the intended weighting.

    properties
        % Beta  Scalar weight on the T term; the (1-T) term is weighted by
        % (1-Beta).
        Beta = 0.7;
    end

    methods
        function layer = CEDLossLayer(name)
            % layer = CEDLossLayer(name) creates the layer and, when NAME is
            % given, sets its Name property.
            % (The original tested nargin == 2, which can never be true for a
            % one-argument constructor, so Name was never set.)
            if nargin >= 1
                layer.Name = name;
            end
            layer.Description = 'cross entropy';
        end

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the SCALAR weighted
            % cross-entropy between predictions Y and training targets T.
            N = size(Y, 4) * size(Y, 1) * size(Y, 2);
            W = CEDLossLayer.weightMaps(T);
            % Bound both Y and (1-Y) away from zero so neither log hits -Inf.
            Ypos = nnet.internal.cnn.util.boundAwayFromZero(Y);
            Yneg = nnet.internal.cnn.util.boundAwayFromZero(1 - Y);
            loss_i = layer.Beta .* W .* T .* log(Ypos) ...
                + (1 - layer.Beta) .* (1 - W) .* (1 - T) .* log(Yneg);
            % Sum over every dimension (including the batch dimension) so the
            % result is a scalar for any batch size.
            loss = -sum(loss_i(:)) ./ N;
        end

        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns dL/dY for the loss
            % computed in forwardLoss; the result has the same size as Y.
            %   dL/dY = -(1/N) * ( Beta*W.*T./Y - (1-Beta)*(1-W).*(1-T)./(1-Y) )
            N = size(Y, 4) * size(Y, 1) * size(Y, 2);
            W = CEDLossLayer.weightMaps(T);
            Ypos = nnet.internal.cnn.util.boundAwayFromZero(Y);
            Yneg = nnet.internal.cnn.util.boundAwayFromZero(1 - Y);
            dLdY = -(layer.Beta .* W .* T ./ Ypos ...
                - (1 - layer.Beta) .* (1 - W) .* (1 - T) ./ Yneg) ./ N;
        end
    end

    methods (Static, Access = private)
        function W = weightMaps(T)
            % W = weightMaps(T) returns an H-by-W-by-1-by-batch array whose
            % page n is WeightMap applied to channel 1 of observation n of T.
            % Each slice is gathered to host memory first because bwlabel /
            % bwdist inside WeightMap reject some gpuArray inputs (the error
            % reported in the question).
            if ~exist('WeightMap', 'file')
                % WeightMap lives outside the class folder; only touch the
                % path when it is not already reachable, instead of calling
                % addpath on every forward/backward pass.
                addpath('D:\Ditsnave Transform')
            end
            nRows = size(T, 1);
            nCols = size(T, 2);
            nObs = size(T, 4);
            W = zeros(nRows, nCols, 1, nObs);
            for n = 1:nObs
                W(:, :, 1, n) = WeightMap(gather(T(:, :, 1, n)));
            end
        end
    end
end
Raza Ali
Raza Ali on 10 Sep 2020
%% Weight Map Function
function [weight]=WeightMap(gt, w0, sigma)
% WEIGHTMAP  U-Net-style per-pixel weight map for a label image.
%   weight = WeightMap(gt) returns an array the size of gt equal to
%   wc + bwgt, where
%     wc(x)   is a class-balance weight: 1 / (number of pixels sharing
%             gt(x)'s value), normalized so the RAREST value gets weight 1;
%     bwgt(x) = w0 * exp(-(d1(x)+d2(x))^2 / (2*sigma^2)) on pixels where
%             gt ~= 1, and 0 elsewhere. d1 and d2 are the distances to the
%             nearest and second-nearest 4-connected region of pixels with
%             gt == 1. bwgt is all zeros when there are fewer than two
%             such regions.
%   weight = WeightMap(gt, w0, sigma) overrides the defaults w0 = 10 and
%   sigma = 5 (the values previously hard-coded); pass [] to keep a default.
%
%   gt is gathered to host memory first, so bwlabel/bwdist always run on
%   the CPU even when the caller passes a gpuArray.
if nargin < 2 || isempty(w0)
    w0 = 10;
end
if nargin < 3 || isempty(sigma)
    sigma = 5;
end
gt = gather(gt);

% Class-balance weights w_c(x): inverse pixel count per distinct value.
uvals = unique(gt);
wmp = zeros(1, length(uvals));
for uv = 1:length(uvals)
    wmp(uv) = 1/sum(gt(:) == uvals(uv));
end
% Normalize so the largest weight (the rarest value) is exactly 1.
wmp = wmp/max(wmp);
wc = zeros(size(gt));
for uv = 1:length(uvals)
    wc(gt == uvals(uv)) = wmp(uv);
end

% Border term: emphasizes background pixels squeezed between two regions.
cells = bwlabel(gt == 1, 4);
nCells = max(cells(:));
bwgt = zeros(size(gt));
if nCells >= 2
    maps = zeros(size(gt, 1), size(gt, 2), nCells);
    for ci = 1:nCells
        maps(:, :, ci) = bwdist(cells == ci);
    end
    % After sorting along dim 3, page 1/2 hold the nearest/second-nearest
    % region distance for every pixel.
    maps = sort(maps, 3);
    d1 = maps(:, :, 1);
    d2 = maps(:, :, 2);
    bwgt = w0*exp(-((d1 + d2).^2)./(2*sigma^2)).*(cells == 0);
end

% U-Net weights
weight = wc + bwgt;
end

Sign in to comment.

Answers (0)

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!