Computing the Hessian with dlgradient
Hi everyone.
I am using a training loop for my model in which the gradients are computed by dlgradient. As you know, when dlgradient (called through dlfeval) differentiates with respect to dlnet.Learnables, it returns a TABLE in which the layers, parameters (weights and biases) and gradient values are stored. dlgradient accepts the loss as a SCALAR together with dlnet.Learnables, while the data samples dlX and targets dlY are passed to the model function through dlfeval. I am interested in computing the Hessian of a small network using dlX and dlY. In fact, if I use a mini-batch dlX, I will be computing a sub-sampled Hessian (so storing this matrix is not a problem for me). However, I do not know how to apply dlgradient a second time to compute the Hessian. If someone knows, I would be thankful.
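For context, the first-order pattern described above can be sketched as follows. The tiny network, the random mini-batch and the function name modelLoss are illustrative assumptions, not taken from the question:

```matlab
% Hypothetical tiny network (sizes chosen only for illustration)
layers = [
    featureInputLayer(3)
    fullyConnectedLayer(4)
    reluLayer
    fullyConnectedLayer(1)];
dlnet = dlnetwork(layers);

% Random mini-batch: 3 features x 8 observations, format 'CB'
dlX = dlarray(rand(3, 8, 'single'), 'CB');
dlY = dlarray(rand(1, 8, 'single'), 'CB');

% dlfeval traces the computation so that dlgradient can differentiate it
[loss, grads] = dlfeval(@modelLoss, dlnet, dlX, dlY);
disp(grads)   % table with variables Layer, Parameter and Value

% modelLoss is a hypothetical helper name
function [loss, grads] = modelLoss(dlnet, dlX, dlY)
    dlYPred = forward(dlnet, dlX);
    loss = mse(dlYPred, dlY);                    % scalar loss
    grads = dlgradient(loss, dlnet.Learnables);  % one row per learnable array
end
```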
Answers (1)
Yash
on 18 Dec 2023
Hi Mahsa,
To compute the Hessian with dlgradient, you differentiate the gradients a second time. dlgradient only differentiates a scalar, so first compute the gradient of the loss inside your model function with 'EnableHigherDerivatives' set to true (this records the gradient computation so that it can itself be differentiated), flatten the gradients into one vector g, and then call dlgradient once for each entry g(i). Each of these calls returns row i of the Hessian, that is, the second-order partial derivatives of the loss with respect to parameter i and every other parameter.
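Before applying this to a network, the mechanism can be checked on a function whose Hessian is known in closed form. This is a minimal sketch, assuming a release whose dlgradient supports the 'EnableHigherDerivatives' option; the quadratic f(x) = 0.5*x'*A*x + b'*x, the values of A, b and x0, and the function name quadHessian are illustrative choices, not part of the question. Because A is symmetric, the Hessian of f is exactly A, so the result can be compared directly:

```matlab
% Known test problem: f(x) = 0.5*x'*A*x + b'*x, whose Hessian is A
A = [4 1 0; 1 3 2; 0 2 5];   % symmetric matrix chosen for illustration
b = [1; -2; 0.5];
x0 = dlarray([0.3; -1.2; 2.0]);

H = dlfeval(@quadHessian, x0, A, b);
disp(H)   % should match A up to rounding

% quadHessian is a hypothetical helper name
function H = quadHessian(x, A, b)
    f = 0.5 * sum(x .* (A * x)) + sum(b .* x);   % scalar objective
    % First derivatives, traced so they can be differentiated again
    g = dlgradient(f, x, 'EnableHigherDerivatives', true, 'RetainData', true);
    n = numel(x);
    H = zeros(n, n);
    for i = 1:n
        % Row i of the Hessian: gradient of the i-th first derivative
        hi = dlgradient(g(i), x, 'RetainData', true);
        H(i, :) = extractdata(hi)';
    end
end
```

The same loop, with dlnet.Learnables in place of x and the gradient table flattened into a vector, gives the Hessian of the network loss.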
Here is a code snippet you can use as a reference:
It assumes that dlnet is your network (a dlnetwork), dlX and dlY are your data samples and targets (dlarray objects), and mse is your loss function.
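% Call the model function through dlfeval so that the computation is traced
[loss, gradients, H] = dlfeval(@modelHessian, dlnet, dlX, dlY);

function [loss, gradients, H] = modelHessian(dlnet, dlX, dlY)
% Forward pass and scalar loss
dlYPred = forward(dlnet, dlX);
loss = mse(dlYPred, dlY);
% First derivatives; EnableHigherDerivatives keeps them differentiable
gradients = dlgradient(loss, dlnet.Learnables, ...
    'EnableHigherDerivatives', true, 'RetainData', true);
% Flatten every gradient array into one column vector g
g = cellfun(@(v) v(:), gradients.Value, 'UniformOutput', false);
g = cat(1, g{:});
n = numel(g);
H = zeros(n, n);
for i = 1:n
    % Row i of the Hessian: gradient of g(i) with respect to all learnables
    hi = dlgradient(g(i), dlnet.Learnables, 'RetainData', true);
    hi = cellfun(@(v) v(:), hi.Value, 'UniformOutput', false);
    H(i, :) = extractdata(cat(1, hi{:}))';
end
end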
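The first dlgradient call returns the usual table of gradients (variables Layer, Parameter and Value), and H is the n-by-n Hessian, where n is the total number of learnable values in the network, not the number of rows in dlnet.Learnables. Because the loss is computed on the mini-batch dlX, H is the sub-sampled Hessian for that batch. Filling H takes n additional dlgradient calls, so this approach is practical only for small networks.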
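To see how large H gets before running the loop, the number of learnable values can be counted from the Value column of dlnet.Learnables. This sketch uses a hypothetical small network (3 inputs, 4 hidden units, 1 output); the sizes are assumptions, not from the question:

```matlab
% Hypothetical small network (sizes chosen only for illustration)
layers = [
    featureInputLayer(3)
    fullyConnectedLayer(4)
    reluLayer
    fullyConnectedLayer(1)];
dlnet = dlnetwork(layers);

% Learnables has one row per parameter array, so count the scalar
% values stored in its Value column instead of its rows
n = sum(cellfun(@numel, dlnet.Learnables.Value));   % (4*3 + 4) + (1*4 + 1) = 21

fprintf('%d learnable values -> %d-by-%d Hessian (%d entries)\n', n, n, n, n^2);
```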
Hope this helps!