Multiple inputs to a pre-trained model

Rayan Matlob on 5 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
I have three class folders (Good, Moderate and Severe).
Each class folder contains an Original images subfolder and five filtered subfolders (Red, Blue, Green, HUE, Value). The filtered subfolders hold the images obtained by applying filters to the images in Original images.
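If, as their names suggest, the five filters are plain channel extractions (the R, G and B planes, and the hue and value channels of HSV), they could be recomputed from the original image instead of being read from disk, which would also cover validation. The sketch below assumes exactly that, plus 8-bit RGB originals; `deriveChannels` and the other helper names are hypothetical, and only base-MATLAB functions (`rgb2hsv`, `cat`) are used.

```matlab
% Hypothetical helper: rebuild the five filtered planes from an 8-bit RGB image.
% ASSUMPTION: Red/Green/Blue are the RGB planes and HUE/Value are the H and V
% channels of HSV. If other filters were used, this does not reproduce them.
toUnit = @(rgb) double(rgb) / 255;          % uint8 -> double in [0, 1]
plane  = @(A, k) A(:, :, k);                % k-th channel of an H x W x C array
hsvOf  = @(rgb) rgb2hsv(toUnit(rgb));       % base-MATLAB RGB -> HSV conversion

% Output channel order: original R, G, B, then Red, Blue, Green, HUE, Value
% (the subfolder order used in the question) -> H x W x 8.
deriveChannels = @(rgb) cat(3, toUnit(rgb), ...
    plane(toUnit(rgb), 1), plane(toUnit(rgb), 3), plane(toUnit(rgb), 2), ...
    plane(hsvOf(rgb), 1), plane(hsvOf(rgb), 3));
```

Under this assumption the Red, Green and Blue planes duplicate channels already present in the original image.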
I am using a pre-trained model (resnet50, or any other model you suggest). The images in all folders are numbered in the same sequence: each subfolder contains images 1 to 200.
How can I train the model so that each image from the Original images subfolder is applied to the model's input in parallel with the corresponding images from the other subfolders (Red, Blue, Green, HUE, Value)?
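One simple way to give a single CNN all six images at once is early fusion: concatenate the original RGB image and the five filtered images along the channel dimension, so each sample becomes one H×W×8 array. This is a sketch, assuming the filtered images are single-channel, 8-bit, and the same size as their original; `stackChannels` is a hypothetical name.

```matlab
% Early fusion (hypothetical helper): one H x W x 8 input per sample.
% ASSUMPTIONS: rgb is H x W x 3 uint8; planes is a 1x5 cell array of
% H x W uint8 images in the order Red, Blue, Green, HUE, Value.
toSingle = @(A) single(A) / 255;                       % uint8 -> single in [0, 1]
stackChannels = @(rgb, planes) cat(3, toSingle(rgb), toSingle(cat(3, planes{:})));
```

Because the pretrained input layer expects 3 channels, feeding such 8-channel arrays also requires changing the network's first layers.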
Note: for validation, I need to use only the Original images folder, and the model should fetch the corresponding images from the other subfolders.
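For the validation requirement above, one option is to index only the Original images folders and let a `transform` function fetch each image's siblings by swapping the folder name, since the file names match across subfolders. This sketch assumes the folder names exactly as written in the question, single-channel uint8 filtered images of the same size as the originals, and the root path from the question's code; `readWithSiblings`, `addSiblings` and the other helper names are hypothetical. It sets the labels from the class folder because `'LabelSource','foldernames'` labels each image with its immediate parent folder (e.g. Original images) rather than Good/Moderate/Severe.

```matlab
root = 'C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3';   % root path from the question
classNames  = {'Good', 'Moderate', 'Severe'};
filterNames = {'Red', 'Blue', 'Green', 'HUE', 'Value'};         % sibling subfolders
origSeg     = [filesep 'Original images' filesep];

% Index only the original images (fullfile with a cell array gives one folder per class).
imds = imageDatastore(fullfile(root, classNames, 'Original images'));

% Label = name of the class folder two levels above each file.
classDirs = cellfun(@(f) fileparts(fileparts(f)), imds.Files, 'UniformOutput', false);
[~, labelNames] = cellfun(@fileparts, classDirs, 'UniformOutput', false);
imds.Labels = categorical(labelNames);
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.77, 'randomized');

% For an original file f, read the same-named image from each filtered subfolder
% and stack everything along the channel dimension (3 + 5 = 8 channels).
siblingOf = @(f, d) strrep(f, origSeg, [filesep d filesep]);
catPlanes = @(c) cat(3, c{:});
readWithSiblings = @(I, f) cat(3, I, catPlanes( ...
    cellfun(@(d) imread(siblingOf(f, d)), filterNames, 'UniformOutput', false)));

% With 'IncludeInfo', transform passes the read info (Filename, Label) and expects
% [data, info] back; deal returns both outputs from the anonymous function.
addSiblings  = @(I, info) deal({readWithSiblings(I, info.Filename), info.Label}, info);
dsTrain      = transform(imdsTrain, addSiblings, 'IncludeInfo', true);
dsValidation = transform(imdsValidation, addSiblings, 'IncludeInfo', true);
```

These datastores would replace `augimdsTrain` and `augimdsValidation` in `trainNetwork` and `'ValidationData'` (and `valFrequency` would use `numel(imdsTrain.Files)`). Since `augmentedImageDatastore` is no longer involved, the images must already match the network input size or be resized inside the transform, for example with `imresize` from Image Processing Toolbox.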
My MATLAB code is below. Thanks.
Rayan Matlob on 6 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
imds = imageDatastore('C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');   % labels = immediate parent folder of each image
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.77,'randomized');
numTrainImages = numel(imdsTrain.Labels);
net = resnet50;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
% findLayersToReplace is a supporting function from the MathWorks transfer-learning example, not a built-in
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));
% Replace the final learnable layer with one sized for the new classes
if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
layers = lgraph.Layers;
connections = lgraph.Connections;
% Resize images on the fly to the network input size
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',60, ...
    'InitialLearnRate',0.00065, ...
    'Shuffle','every-epoch', ...
    'ValidationFrequency',valFrequency, ...
    'ValidationData',augimdsValidation, ...
    'Verbose',false);
net = trainNetwork(augimdsTrain,lgraph,options);
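With 8-channel inputs (3 original RGB + 5 filtered planes), the pretrained ResNet-50 cannot be used unchanged: its input layer expects 224×224×3 and its first convolution has weights for 3 input channels. A minimal sketch, assuming MATLAB's `resnet50` names these layers `'input_1'` and `'conv1'` and that the ResNet-50 support package is installed: replace the input layer with an 8-channel one and rebuild `conv1`, reusing the pretrained RGB filters and initializing the five extra channels with their mean across RGB. That initialization is one simple heuristic, not the only choice; normalization is set to `'none'` because the pretrained statistics only cover 3 channels.

```matlab
net = resnet50;
lgraph = layerGraph(net);
inputSize = net.Layers(1).InputSize;              % [224 224 3]
numChannels = 8;                                  % assumption: 3 RGB + 5 filtered planes

% 8-channel input layer; no normalization, since the pretrained statistics are 3-channel.
newInput = imageInputLayer([inputSize(1:2) numChannels], ...
    'Name', 'input_1', 'Normalization', 'none');
lgraph = replaceLayer(lgraph, 'input_1', newInput);

% Rebuild conv1 with the same geometry but numChannels input channels.
layers = lgraph.Layers;
oldConv1 = layers(strcmp({layers.Name}, 'conv1'));
W = oldConv1.Weights;                             % FilterSize x 3 x NumFilters
Wextra = repmat(mean(W, 3), 1, 1, numChannels - 3);   % extra channels: mean RGB filter
newConv1 = convolution2dLayer(oldConv1.FilterSize, oldConv1.NumFilters, ...
    'Name', 'conv1', ...
    'Stride', oldConv1.Stride, ...
    'Padding', oldConv1.PaddingSize, ...
    'Weights', cat(3, W, Wextra), ...
    'Bias', oldConv1.Bias);
lgraph = replaceLayer(lgraph, 'conv1', newConv1);
```

The `new_fc` and `new_classoutput` replacements from the question can then be applied to this `lgraph` unchanged; all other layers keep their pretrained weights.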

Answers (0)

