
How do I use k-fold cross-validation in Deep Network Designer?

Hello, I'm working on a project in which I use Deep Network Designer to create a U-Net architecture adapted for image regression. I need to do k-fold cross-validation because my training dataset is small: just 30 pairs of 2D images. Can anybody please tell me what I should do to validate the model with k-fold cross-validation?

Answers (1)

Rahul on 4 Jan 2023
"deepNetworkDesigner" does not provide k-fold cross-validation as such. As a workaround, export the layers you built in the app to the workspace and run the cross-validation loop in a script, using the following steps.
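The example assumes 1000 input images and 1000 label (target) images, each 28x28 with a single channel, stored as 4-D arrays of size height x width x channels x number of images:
imgs: 28 x 28 x 1 x 1000
labels: 28 x 28 x 1 x 1000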
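For real data such as the 30 image pairs in the question, the same 4-D layout can be built by concatenating the images along the fourth dimension. This is a minimal sketch, assuming each image is a grayscale H-by-W matrix held in a cell array; the random 64x64 images are hypothetical stand-ins for images you would read from disk (for example with `readall` on an `imageDatastore`), and the sizes are illustrative only.

```matlab
% Hypothetical stand-ins for 30 input/target image pairs read from disk.
% Replace these cell arrays with your own images (same order in both).
numPairs    = 30;
inputCells  = arrayfun(@(i) rand(64, 64), 1:numPairs, 'UniformOutput', false);
targetCells = arrayfun(@(i) rand(64, 64), 1:numPairs, 'UniformOutput', false);

% Stack along the 4th dimension: H x W x 1 x N (one channel for grayscale).
imgs   = single(cat(4, inputCells{:}));    % size [64 64 1 30]
labels = single(cat(4, targetCells{:}));   % size [64 64 1 30]
```

Every image must have the same height and width for `cat` to succeed; RGB images would give a channel dimension of 3 instead of 1.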
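```matlab
% Example data (random values stand in for real images)
imgs   = randi([0, 255], 28, 28, 1, 1000);   % 1000 input images, 28x28, single plane
labels = randi([0, 255], 28, 28, 1, 1000);   % 1000 label images, 28x28, single plane

params.LR = 0.001;
params.Maxepochs = 100;
params.num_batch = 16;

% Your U-Net layers exported from Deep Network Designer. "lgraph" is a
% placeholder name; for image regression the network must end in a regressionLayer.
layers = lgraph;

options = trainingOptions('adam', ...
    'InitialLearnRate', params.LR, ...
    'MaxEpochs', params.Maxepochs, ...
    'MiniBatchSize', params.num_batch);      % change these as per your requirement

%% k-fold cross-validation
kfold_val = 10;                              % number of folds (10-fold cross-validation)
num_samples = size(labels, 4);               % number of image pairs
fold = cvpartition(num_samples, 'KFold', kfold_val);   % random k-fold partition
fold_rmse = zeros(kfold_val, 1);             % validation RMSE of each fold

for ii = 1:kfold_val
    train_idx = fold.training(ii);           % logical index of training samples
    validation_idx = fold.test(ii);          % logical index of held-out samples

    % extract training images and labels using train_idx
    xtrain = imgs(:, :, :, train_idx);
    ytrain = labels(:, :, :, train_idx);
    % extract validation images and labels using validation_idx
    xvalid = imgs(:, :, :, validation_idx);
    yvalid = labels(:, :, :, validation_idx);

    % train on this fold; "layers" is not modified, so each fold starts
    % from the original layers, not from the previous fold's network
    trained_net = trainNetwork(xtrain, ytrain, layers, options);

    % predict on the held-out images and compute the RMSE for this fold
    Pred = predict(trained_net, xvalid);
    fold_rmse(ii) = sqrt(mean((double(Pred(:)) - double(yvalid(:))).^2));
end

fprintf('Mean validation RMSE over %d folds: %.4f (std %.4f)\n', ...
    kfold_val, mean(fold_rmse), std(fold_rmse));
```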
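`cvpartition` belongs to the Statistics and Machine Learning Toolbox. If that toolbox is not available, a minimal sketch like the following assigns folds with base MATLAB only. It assumes the 30 image pairs from the question and 5 folds (both values are illustrative); the resulting logical masks play the same role as `fold.training(ii)` and `fold.test(ii)` in the loop above.

```matlab
% Minimal k-fold assignment without cvpartition (base MATLAB only).
n = 30;                                  % number of image pairs (from the question)
k = 5;                                   % number of folds (an assumption; choose your own)
foldId = mod(randperm(n) - 1, k) + 1;    % random fold label in 1..k for each sample

for ii = 1:k
    validation_idx = (foldId == ii);     % samples held out in fold ii
    train_idx      = ~validation_idx;    % all remaining samples are used for training
    fprintf('Fold %d: %d training, %d validation samples\n', ...
        ii, nnz(train_idx), nnz(validation_idx));
end
```

Because `randperm` shuffles the sample positions before the `mod` step, fold sizes differ by at most one; with 30 pairs and 5 folds, each fold validates on 6 images and trains on 24.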
Please check the relevant MATLAB documentation pages for reference.
You can also go through the related MATLAB Central page for more information.
Please note that this workflow was not designed by MathWorks; contact the author in case of any issues.
