How to classify with DAG network from checkpoint
I want to use classify() with a DAG network loaded from a checkpoint.
I trained inceptionv3 with transfer learning for many epochs, and training succeeded. I set 'CheckpointPath', so I have a saved network for each epoch. I want to evaluate these networks, so I loaded one and called classify(). However, an error occurred, and the message said "Use trainNetwork". How can I use classify() with a network loaded from a checkpoint?
carlos arizmendi
on 23 Nov 2019
I now have the same problem when classifying. How did you fix this bug? Thanks a lot.
Accepted Answer
Naoya
on 15 Oct 2018
Thank you very much for providing the details.
Checkpoint networks that contain BatchNormalization layers cannot be used for prediction in the current release (R2018b). I will forward this functionality to our development team as an enhancement request.
We apologize for the inconvenience caused by the current checkpoint functionality.
carlos arizmendi
on 23 Nov 2019
I have the same problem right now, and I am running the R2019b version of Deep Learning Toolbox. How can I fix the problem? I need this as quickly as possible. Thank you.
Error using nnet.internal.cnn.layer.BatchNormalization/predict (line 135)
Unable to use networks with batch normalization layers before training is complete. Use trainNetwork to complete network training.
Error in nnet.internal.cnn.DAGNetwork/predict (line 383)
outputActivations = thisLayer.predict(XForThisLayer);
Error in DAGNetwork/predict (line 622)
YBatch = predictNetwork.predict(X);
Error in DAGNetwork/classify (line 693)
scores = this.predict( X, varargin{:} );
Wes Baldwin
on 7 Jul 2020
I just had this same issue in R2019b. How is this an enhancement? This is a bug that needs to be fixed!
More Answers (3)
Katja Mogalle
on 30 Apr 2021
@Gediminas Simkus had the right idea for the workaround. I can sketch this out a bit more.
Background information
To make predictions with the network after training, batch normalization requires a fixed mean and variance to normalize the data. By default, these fixed values are calculated from the training data at the very end of training, using the entire training data set. A checkpoint network, however, is saved before the end of training is reached, so its mean and variance values are not set.
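For reference, the standard per-channel batch normalization computation is shown below. During training, the mean and variance are taken from each mini-batch; for inference, the layer uses fixed values instead, which batchNormalizationLayer stores in its TrainedMean and TrainedVariance properties. Here the scale and offset correspond to the layer's Scale and Offset properties, and the small constant to its Epsilon property. This is the textbook formulation, given as a sketch of why the fixed statistics matter rather than as a description of the toolbox internals.

```latex
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta
```

Without fixed values for the mean and the variance, the inference-time output cannot be computed, which is what the "before training is complete" error reports.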
Two possible solutions
There are two things you can try in order to use checkpoint networks for inference:
- Since R2021a, batch normalization layers can compute running statistics. The batch normalization statistics are then calculated during training rather than at the end of training, so the checkpoint networks can be used directly without further modification. To do this, set the 'BatchNormalizationStatistics' name-value argument of trainingOptions to 'moving' when training the network with checkpointing.
- Use trainNetwork with minimal training to convert the checkpoint network into a network with fixed batch normalization mean and variance that can be used for inference. The workaround is based on the process to Resume Training from Checkpoint Network but with some slight tweaks in order to modify the checkpointed network as little as possible.
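A minimal sketch of the first option, assuming R2021a or later. The checkpoint folder and the MaxEpochs value are placeholders for illustration; the folder is created first because the checkpoint path has to exist.

```matlab
% Sketch: train with checkpointing and running batch normalization statistics.
% checkpointDir and MaxEpochs are illustrative placeholders; use your own values.
checkpointDir = fullfile(tempdir, "checkpoints");
if ~isfolder(checkpointDir)
    mkdir(checkpointDir);   % the checkpoint folder must exist before training
end

options = trainingOptions("sgdm", ...
    "MaxEpochs", 30, ...
    "CheckpointPath", checkpointDir, ...               % save a network during training
    "BatchNormalizationStatistics", "moving");         % update mean/variance during training
```

Pass options to trainNetwork as usual. According to the answer above, each net_checkpoint__*.mat file saved in checkpointDir can then be loaded and passed to classify without the trainNetwork conversion described next.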
Example steps for the second workaround using trainNetwork (tested in R2020a and R2020b)
Load the checkpoint network into the workspace (replace this with your own file).
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
Specify the training options such that training is only run for one iteration, the input data statistics of the input layer are not recomputed, and the learnable parameters are only changed minimally.
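options = trainingOptions('sgdm', ...
    'InitialLearnRate',eps, ...          % tiny learning rate: weights barely change
    'ResetInputNormalization',false, ... % keep the input layer's normalization statistics
    'OutputFcn',@(~)true );              % stop after the first iteration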
Now “resume” training using the layers of the checkpoint network you loaded with the new training options. If the checkpoint network is a DAG network, then use layerGraph(net) as the argument instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
The returned network can be used for inference.
YPred = classify(net2,XTrain);
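Applied to the original question (evaluating the checkpoint saved at each epoch of a DAG network such as inceptionv3), the steps above can be looped over all checkpoint files. This is a sketch under stated assumptions: checkpointDir, augimdsTrain, augimdsTest, and imdsTest are hypothetical names for your checkpoint folder, your training data, and your labelled test data, and 'ResetInputNormalization' requires R2019b or later.

```matlab
% Sketch: convert and evaluate every checkpoint of a DAG network.
% Hypothetical names, replace with your own:
%   checkpointDir - the folder that was passed as 'CheckpointPath'
%   augimdsTrain  - training data, used only for the one-iteration conversion
%   augimdsTest   - test data built from the labelled imageDatastore imdsTest
checkpointDir = fullfile(tempdir, "checkpoints");
files = dir(fullfile(checkpointDir, "net_checkpoint__*.mat"));

% Options from the steps above: one iteration, tiny learning rate, input statistics kept.
options = trainingOptions("sgdm", ...
    "InitialLearnRate", eps, ...
    "ResetInputNormalization", false, ...
    "OutputFcn", @(~) true);

accuracy = zeros(numel(files), 1);
for k = 1:numel(files)
    s = load(fullfile(files(k).folder, files(k).name), "net");
    % DAG checkpoint: pass the layer graph instead of s.net.Layers.
    net2 = trainNetwork(augimdsTrain, layerGraph(s.net), options);
    YPred = classify(net2, augimdsTest);
    accuracy(k) = mean(YPred == imdsTest.Labels);
    fprintf("%s: test accuracy %.4f\n", files(k).name, accuracy(k));
end
```

dir returns the files in alphabetical order, which does not match iteration order because the iteration number in the file name is not zero-padded, so the printed file name is what identifies each checkpoint. Each conversion recomputes the statistics over the whole training set, so expect roughly one pass over the training data per checkpoint. If the checkpoints were trained with 'BatchNormalizationStatistics' set to 'moving', the trainNetwork call can be dropped and s.net passed to classify directly.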
I hope this helps.
Katja Mogalle
on 19 Oct 2021
The option 'ResetInputNormalization' of training options was added in R2019b.
Nithin M
on 29 Oct 2021
Thank you for the detailed post.
I have a question: what is the impact of setting BatchNormalizationStatistics to 'moving' on training time? Will training take considerably longer, or will it have little effect?
AnaMota
on 27 Apr 2021
Is there any solution for this? I am facing the same issue with MATLAB 2020...
Andrea Daou
on 8 Oct 2021
Hello,
I know an answer has already been accepted for this question, but I have a response that might be useful.
If using a network from a checkpoint does not work in your MATLAB version, you can instead write an output function that stops training early, similar to the one in https://fr.mathworks.com/help/deeplearning/ug/customize-output-during-deep-learning-training.html . Because training then ends normally, the network returned by trainNetwork can be used with classify.
For example, instead of stopping based on validation accuracy, as that example does, the function below stops based on validation loss.
function stop = stopIfValidationLossNotDecreasing(info,N,StartPoint)
    stop = false;

    % Keep track of the validation loss and the number of successive validations for which
    % there has not been a decrease in the loss.
    persistent ValLoss
    persistent valLag

    % Reset the variables when training starts.
    if info.State == "start"
        ValLoss = StartPoint; % Initial reference loss; choose it for your problem, e.g. from a first validation loss.
        valLag = 0;
    elseif ~isempty(info.ValidationLoss)
        % Compare the current validation loss with the previous one: if it is
        % lower, reset valLag; otherwise, increment valLag by 1. In both cases,
        % the current loss becomes the reference for the next comparison.
        if info.ValidationLoss < ValLoss
            valLag = 0;
        else
            valLag = valLag + 1;
        end
        ValLoss = info.ValidationLoss;

        % If the validation lag is at least N, that is, the validation loss
        % has not decreased for at least N validations in a row, then return true and
        % stop training.
        if valLag >= N
            stop = true;
        end
    end
end
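A sketch of how the function above could be attached to training, assuming it is saved as stopIfValidationLossNotDecreasing.m on the MATLAB path. The random arrays are placeholders only so the snippet runs, and N = 5 is an illustrative choice. Passing Inf as StartPoint makes the first validation loss always count as a decrease, so no problem-specific starting value has to be chosen. Note that the function compares each validation loss with the previous one, not with the best loss so far, so it counts consecutive non-decreasing validations.

```matlab
% Sketch: use the custom stopping function as a training output function.
% Placeholder data only; replace with your real validation images and labels.
XValidation = rand(28, 28, 1, 20);          % placeholder: 20 grayscale 28x28 images
YValidation = categorical(randi(2, 20, 1)); % placeholder: 2-class labels

options = trainingOptions("sgdm", ...
    "ValidationData", {XValidation, YValidation}, ...  % needed so info.ValidationLoss is set
    "ValidationFrequency", 50, ...                      % validate every 50 iterations
    "OutputFcn", @(info) stopIfValidationLossNotDecreasing(info, 5, Inf));
```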
Katja Mogalle
on 19 Oct 2021
Hi Andrea,
The training option "ValidationPatience" actually does exactly what you're showing in your code. To stop training when the loss on the validation set stops decreasing, simply specify validation data and a validation patience using the 'ValidationData' and the 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops.
See this doc page as a reference: https://uk.mathworks.com/help/deeplearning/ref/trainingoptions.html?s_tid=doc_ta#d123e136007
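A minimal sketch of that built-in alternative. The random arrays are placeholders only so the snippet runs, and the frequency and patience values are illustrative.

```matlab
% Sketch: stop training when the validation loss stops improving,
% using the built-in 'ValidationPatience' option instead of a custom OutputFcn.
% Placeholder data only; replace with your real validation images and labels.
XValidation = rand(28, 28, 1, 20);          % placeholder: 20 grayscale 28x28 images
YValidation = categorical(randi(2, 20, 1)); % placeholder: 2-class labels

options = trainingOptions("sgdm", ...
    "ValidationData", {XValidation, YValidation}, ...
    "ValidationFrequency", 50, ...   % validate every 50 iterations
    "ValidationPatience", 5);        % stop once the loss has failed 5 times to beat the smallest loss so far
```

Unlike the custom function above, which compares each loss with the previous validation loss, ValidationPatience compares it with the smallest loss seen so far, as described in the linked page.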
But perhaps I am not fully understanding what you are trying to achieve. In that case, perhaps you could provide some clarification?
Thanks