Need to improve semantic segmentation using a deep network and transfer learning
I'm trying to use a pre-trained network to do transfer learning for image segmentation to detect different terrain types, following the example in Multispectral Semantic Segmentation Using Deep Learning.
I've trained a U-Net on the data from the example, and it performs decently (accuracy in the 70-90% range).
%data and helper functions from the example page
imds = imageDatastore("train_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
pxds = pixelLabelDatastore("train_labels.png",classNames,pixelLabelIds);
dsTrain = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=1000);
dsTrain2 = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=500); %use during transfer learning
val_ds = imageDatastore("val_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
val_pxds = pixelLabelDatastore("val_labels.png",classNames,pixelLabelIds);
dsVal2 = randomPatchExtractionDatastore(val_ds,val_pxds,[256,256],PatchesPerImage=100);%use during transfer learning
lgraph = unetLayers([256,256,6], 18, 'EncoderDepth', 4);
options = trainingOptions("sgdm",...
InitialLearnRate=0.05, ...
Momentum=0.9,...
L2Regularization=0.0001,...
MaxEpochs=10,...
MiniBatchSize=8,...
LearnRateSchedule="piecewise",...
Shuffle="every-epoch",...
GradientThresholdMethod="l2norm",...
GradientThreshold=0.05, ...
Plots="training-progress", ...
VerboseFrequency=20);
net = trainNetwork(dsTrain,lgraph,options);
save("my_multispectralUnet.mat", "net");
Results from training: [training-progress plot not available]
To test the transfer learning part, I replace the final three layers.
data = load("my_multispectralUnet.mat");
basic_trained = data.net;
layersTransfer = basic_trained.Layers(1:end-3);
numClasses = 18; %there are 18 classes in the data
layers = [
layersTransfer
convolution2dLayer(1,numClasses, 'BiasL2Factor',0,'Padding', 'same','Name','new_Final-ConvolutionLayer', ...
'WeightLearnRateFactor', 10, 'BiasLearnRateFactor',10)
softmaxLayer('Name','new_final_softmax')
pixelClassificationLayer('Name', 'new_final_classification')];
%indexing net.Layers flattens the DAG into a series array, so build a layer graph from it
%and add back in the skip connections that make it U-shaped
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, 'Encoder-Stage-1-ReLU-2','Decoder-Stage-4-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-2-ReLU-2','Decoder-Stage-3-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-3-ReLU-2','Decoder-Stage-2-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-4-DropOut','Decoder-Stage-1-DepthConcatenation/in2');
options = trainingOptions('sgdm', ...
'MiniBatchSize',8, ...
'MaxEpochs',5, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',dsVal2, ...
'ValidationFrequency',floor(100/8), ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(dsTrain2,lgraph,options);
save("trained_616.mat", "netTransfer");
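Indexing basic_trained.Layers flattens the trained DAG into a series array, which is why the skip connections have to be re-added by hand. A possibly simpler route, sketched below, is to convert the trained network with layerGraph and swap only the head with replaceLayer, which leaves every other layer, its trained weights, and all connections intact. This assumes the network was built with unetLayers, so its last three layers still have the default names 'Final-ConvolutionLayer', 'Softmax-Layer' and 'Segmentation-Layer'; check {lgraph.Layers.Name} if yours differ.

```matlab
% Sketch only: replace the head of the trained U-Net in place.
% ASSUMPTION: default unetLayers names for the last three layers
% ('Final-ConvolutionLayer', 'Softmax-Layer', 'Segmentation-Layer').
data = load("my_multispectralUnet.mat");
lgraph = layerGraph(data.net);   % keeps all layers, trained weights and skip connections

numClasses = 18;
newConv = convolution2dLayer(1, numClasses, 'Padding', 'same', ...
    'BiasL2Factor', 0, ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10, ...
    'Name', 'new_Final-ConvolutionLayer');

% replaceLayer removes the named layer and wires the new one into the same place
lgraph = replaceLayer(lgraph, 'Final-ConvolutionLayer', newConv);
lgraph = replaceLayer(lgraph, 'Softmax-Layer', softmaxLayer('Name', 'new_final_softmax'));
lgraph = replaceLayer(lgraph, 'Segmentation-Layer', ...
    pixelClassificationLayer('Name', 'new_final_classification'));
```

The result can be passed straight to trainNetwork(dsTrain2, lgraph, options) with the options from the question. Note that the new 1x1 convolution starts from random weights, so the network cannot match the original accuracy at the start of training; when the class set is unchanged (18 classes here), skipping the replacement altogether keeps the trained final layer.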
I am continuing to train with patches taken from the same dataset (as a proof of concept), so I expect the network to be at least as good as it was originally. Instead, it is much worse.

Can anyone share insights as to what is going wrong? Am I specifying some parameters in the transfer part incorrectly?
Thanks
Ben
on 20 Jun 2023
@Allison - I suspect that replacing the final convolution is causing this issue. That layer has learnable parameters that were already trained, and they are being replaced with randomly initialized values.
This causes two problems. First, because those weights are now random, the network outputs will not be accurate. Second, during training the resulting large errors cause large updates to the rest of the network's pre-trained weights.
The first problem is unavoidable if you want to transfer from one set of pixel classes to another, but the second can be prevented by setting WeightLearnRateFactor and BiasLearnRateFactor for all the pre-trained layers to a small value, or even 0. I expect that setting those factors to 0 would also make the transfer learning training quicker.
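One way to apply that suggestion is sketched below as a hypothetical helper, freezePretrainedLayers (not a toolbox function). It sets both learn-rate factors to 0 on every layer that has learnable weights, except the layers whose names you pass in trainableNames (for example, the new final convolution), and puts each modified copy back with replaceLayer. It assumes lgraph is a layer graph such as the one passed to trainNetwork in the question.

```matlab
function lgraph = freezePretrainedLayers(lgraph, trainableNames)
% Hypothetical helper: freeze every layer with learnable weights except
% the layers named in trainableNames (a cell array of char vectors).
layers = lgraph.Layers;
for i = 1:numel(layers)
    layer = layers(i);
    % In a standard unetLayers U-Net, these are the convolution and
    % transposed convolution layers.
    if isprop(layer, 'WeightLearnRateFactor') && ~ismember(layer.Name, trainableNames)
        layer.WeightLearnRateFactor = 0;
        layer.BiasLearnRateFactor = 0;
        % Put the modified copy back; replaceLayer reconnects it by name.
        lgraph = replaceLayer(lgraph, layer.Name, layer);
    end
end
end
```

Usage, placed at the end of the script or saved as freezePretrainedLayers.m: lgraph = freezePretrainedLayers(lgraph, {'new_Final-ConvolutionLayer'}); before calling trainNetwork. Using a small nonzero factor instead of 0 would let the encoder and decoder adapt slowly rather than not at all.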