Why are all the gradients of the generator zero (from the beginning to the end of training) when training a GAN?
I want to train a generator that produces values of a sine function. However, when training the GAN, all of the generator's gradients are zero. I do not know what the problem is. Could anyone help me?
The code is listed as follows:
batch_size = 64;
n_ideas = 5;
art_components = 15;
step = 2/(art_components-1);
points = -1:step:1;
paint_points = repmat(points,batch_size,1);
Generator = [
    featureInputLayer(n_ideas)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(art_components)
    ];
Discriminator = [
    featureInputLayer(art_components)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
    ];
net_g = dlnetwork(Generator);
net_d = dlnetwork(Discriminator);
lr = 0.0001;
decay = 0.90;
sqdecay = 0.999;
avg_decay_g = [];
avd_sqdecay_g = [];
avg_decay_d = [];
avd_sqdecay_d = [];
for e=1:10000
    artis_paintings = dlarray(single(artist_work(art_components,paint_points)),'BC');
    % update learnable parameters of discriminator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    [loss_d,gradient_d,score_d] = ...
        dlfeval(@d_loss,net_d,artis_paintings,g_paintings);
    [net_d, avg_decay_d, avd_sqdecay_d] = ...
        adamupdate(net_d,gradient_d,avg_decay_d,avd_sqdecay_d,e,lr,decay,sqdecay);
    % update learnable parameters of generator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    [loss_g,gradient_g,score_g] = ...
        dlfeval(@g_loss,net_g,prob_artist1);
    [net_g, avg_decay_g, avd_sqdecay_g] = ...
                adamupdate(net_g,gradient_g,avg_decay_g,avd_sqdecay_g,e,lr,decay,sqdecay);
end
function [loss_d,gradient_d,score_d] = ...
    d_loss(net_d,artis_paintings,g_paintings)
% calculate loss
prob_artist0 = forward(net_d,artis_paintings);
prob_artist1 = forward(net_d,g_paintings);
score_d = mean(1-prob_artist1);
loss_d = -mean(log(prob_artist0)) - mean(log(1-prob_artist1));
% calculate gradients
gradient_d = dlgradient(loss_d, net_d.Learnables);
end
function [loss_g,gradient_g,score_g] = ...
    g_loss(net_g,prob_artist1)
score_g = mean(prob_artist1);
% calculate gradients
loss_g = -mean(log(prob_artist1));
gradient_g = dlgradient(loss_g, net_g.Learnables);
end
function paintings=artist_work(art_components,paint_points)
r = 0.02 * randn(1,art_components);
paintings = sin(paint_points *pi) + r;
end
Accepted Answer
Richard on 5 Nov 2022
All of the calculations between the variables you want gradients with respect to and the loss value need to be contained inside the function that you pass to dlfeval. If they are not, the dlgradient call will not know they have occurred and will conclude there is no dependency between the outputs and the inputs, hence the gradients are all zero.
In this case, you must ensure that the "forward(net)" calls are inside the loss functions. You have done this correctly for the discriminator loss, but for the generator loss you need to pass in both the generator and discriminator networks and call forward on each one inside g_loss:
function [loss_g,gradient_g,score_g] = g_loss(net_g,net_d,g_ideas)
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    score_g = mean(prob_artist1);
    % calculate gradients
    loss_g = -mean(log(prob_artist1));
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
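With this signature, the dlfeval call in the training loop changes accordingly: you pass the raw noise input g_ideas rather than a precomputed forward result, so both forward passes are traced inside dlfeval. A sketch using the same variable names as the question's loop:

```matlab
% update learnable parameters of generator
g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
% both forward passes now happen inside g_loss, so dlgradient can
% trace the dependency from net_g.Learnables through to loss_g
[loss_g,gradient_g,score_g] = dlfeval(@g_loss,net_g,net_d,g_ideas);
[net_g, avg_decay_g, avd_sqdecay_g] = ...
    adamupdate(net_g,gradient_g,avg_decay_g,avd_sqdecay_g,e,lr,decay,sqdecay);
```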
2 Comments
Richard on 6 Nov 2022
Thanks for the feedback @You Jinkun. I will submit a documentation enhancement request regarding this aspect of the dlfeval/dlgradient interaction.
You can also use the "How useful was this information?" section at the bottom of any of our doc pages to directly submit feedback to our doc team if there is a specific page that you think could be improved (clicking on the rating opens a text field for submitting a specific comment).  