Make Predictions Using Model Function
This example shows how to make predictions using a model function by splitting data into mini-batches.
For large data sets, or when predicting on hardware with limited memory, make predictions by splitting the data into mini-batches. When making predictions with a dlnetwork object, the minibatchpredict function automatically splits the input data into mini-batches. For model functions, you must split the data into mini-batches manually.
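As a minimal sketch of what manual splitting involves, the following computes the observation indices of each mini-batch. The sizes are hypothetical values chosen for illustration and are not taken from the example data.

```matlab
% Hypothetical sizes, chosen only for illustration.
numObservations = 300;
miniBatchSize = 128;

% The last mini-batch can be smaller than miniBatchSize.
numBatches = ceil(numObservations/miniBatchSize);

batchSizes = zeros(1,numBatches);
for i = 1:numBatches
    % Indices of the observations in mini-batch i.
    idxStart = (i-1)*miniBatchSize + 1;
    idxEnd = min(i*miniBatchSize,numObservations);
    idx = idxStart:idxEnd;

    % A prediction call on the observations idx would go here.
    batchSizes(i) = numel(idx);
end
```

With these values, the loop visits two full mini-batches of 128 observations and a final partial mini-batch of 44.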
Create Model Function and Load Parameters
Load the model parameters from the MAT file digitsMIMO.mat. The MAT file contains the model parameters in a struct named parameters, the model state in a struct named state, and the class names in classNames.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
The model function model, listed at the end of the example, defines the model given the model parameters and state.
Load Data for Prediction
Load the digits data for prediction.
unzip("DigitsData.zip");
dataFolder = "DigitsData";
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
numObservations = numel(imds.Files);
Make Predictions
Loop over the mini-batches of the test data and make predictions using a custom prediction loop.
Use minibatchqueue to process and manage the mini-batches of images. Specify a mini-batch size of 128. Set the read size property of the image datastore to the mini-batch size.
For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to concatenate the data into a batch and normalize the images.
- Format the images with the dimensions 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single.
- Make predictions on a GPU if one is available. By default, the minibatchqueue object converts the output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
mbq = minibatchqueue(imds,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat="SSCB");
Loop over the mini-batches of data and make predictions using the model function. Use the scores2label function to determine the class labels. Store the predicted class labels.
doTraining = false;
Y1 = [];
Y2 = [];
% Loop over mini-batches.
while hasdata(mbq)
    % Read mini-batch of data.
    X = next(mbq);
    % Make predictions using the model function.
    [Y1Batch,Y2Batch] = model(parameters,X,doTraining,state);
    % Determine corresponding classes.
    Y1Batch = scores2label(Y1Batch,classNames);
    Y1 = [Y1 Y1Batch];
    Y2Batch = extractdata(Y2Batch);
    Y2 = [Y2 Y2Batch];
end
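To illustrate the idea behind converting scores to labels, the following sketch picks the class with the largest score for each observation using only base MATLAB functions. It is an illustration of the concept, not the implementation of scores2label, and the class names and scores are hypothetical.

```matlab
% Hypothetical class names and scores. Rows correspond to classes,
% columns correspond to observations.
classNames = ["0" "1" "2"];
scores = [0.1 0.7 0.2;
          0.8 0.2 0.1;
          0.1 0.1 0.7];

% Find the index of the largest score in each column.
[~,idx] = max(scores,[],1);

% Map the indices to class names.
labels = categorical(classNames(idx),classNames);
```

Here the first observation has its largest score in the second row, so it receives the label "1".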
View some of the images with their predictions.
idx = randperm(numObservations,9);
figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    imshow(I)
    hold on
    sz = size(I,1);
    offset = sz/2;
    thetaPred = Y2(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--")
    hold off
    label = string(Y1(idx(i)));
    title("Label: " + label)
end

Model Function
The function model takes the model parameters parameters, the input data X, the flag doTraining, which specifies whether the model should return outputs for training or prediction, and the network state state. The function returns the predictions for the labels, the predictions for the angles, and the updated network state.
function [Y1,Y2,state] = model(parameters,X,doTraining,state)
    % Convolution
    weights = parameters.conv1.Weights;
    bias = parameters.conv1.Bias;
    Y = dlconv(X,weights,bias,Padding="same");
    % Batch normalization, ReLU
    offset = parameters.batchnorm1.Offset;
    scale = parameters.batchnorm1.Scale;
    trainedMean = state.batchnorm1.TrainedMean;
    trainedVariance = state.batchnorm1.TrainedVariance;
    if doTraining
        [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
        % Update state
        state.batchnorm1.TrainedMean = trainedMean;
        state.batchnorm1.TrainedVariance = trainedVariance;
    else
        Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
    end
    Y = relu(Y);
    % Convolution, batch normalization (Skip connection)
    weights = parameters.convSkip.Weights;
    bias = parameters.convSkip.Bias;
    YSkip = dlconv(Y,weights,bias,Stride=2);
    offset = parameters.batchnormSkip.Offset;
    scale = parameters.batchnormSkip.Scale;
    trainedMean = state.batchnormSkip.TrainedMean;
    trainedVariance = state.batchnormSkip.TrainedVariance;
    if doTraining
        [YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);
        % Update state
        state.batchnormSkip.TrainedMean = trainedMean;
        state.batchnormSkip.TrainedVariance = trainedVariance;
    else
        YSkip = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);
    end
    % Convolution
    weights = parameters.conv2.Weights;
    bias = parameters.conv2.Bias;
    Y = dlconv(Y,weights,bias,Padding="same",Stride=2);
    % Batch normalization, ReLU
    offset = parameters.batchnorm2.Offset;
    scale = parameters.batchnorm2.Scale;
    trainedMean = state.batchnorm2.TrainedMean;
    trainedVariance = state.batchnorm2.TrainedVariance;
    if doTraining
        [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
        % Update state
        state.batchnorm2.TrainedMean = trainedMean;
        state.batchnorm2.TrainedVariance = trainedVariance;
    else
        Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
    end
    Y = relu(Y);
    % Convolution
    weights = parameters.conv3.Weights;
    bias = parameters.conv3.Bias;
    Y = dlconv(Y,weights,bias,Padding="same");
    % Batch normalization
    offset = parameters.batchnorm3.Offset;
    scale = parameters.batchnorm3.Scale;
    trainedMean = state.batchnorm3.TrainedMean;
    trainedVariance = state.batchnorm3.TrainedVariance;
    if doTraining
        [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
        % Update state
        state.batchnorm3.TrainedMean = trainedMean;
        state.batchnorm3.TrainedVariance = trainedVariance;
    else
        Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
    end
    % Addition, ReLU
    Y = YSkip + Y;
    Y = relu(Y);
    % Fully connect, softmax (labels)
    weights = parameters.fc1.Weights;
    bias = parameters.fc1.Bias;
    Y1 = fullyconnect(Y,weights,bias);
    Y1 = softmax(Y1);
    % Fully connect (angles)
    weights = parameters.fc2.Weights;
    bias = parameters.fc2.Bias;
    Y2 = fullyconnect(Y,weights,bias);
end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses the data using the following steps:
- Extract the data from the incoming cell array and concatenate into a numeric array. Concatenating over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
- Normalize the pixel values between 0 and 1.
function X = preprocessMiniBatch(data)
    % Extract image data from cell and concatenate
    X = cat(4,data{:});
    % Normalize the images.
    X = X/255;
end
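The effect of concatenating over the fourth dimension can be checked on a small cell array. This sketch uses three hypothetical 28-by-28 single-precision images, assumed here for illustration, in place of data read from the datastore.

```matlab
% Three hypothetical 28-by-28 grayscale images in a cell array.
% Single precision is an assumption made for this sketch.
data = {single(zeros(28,28)), single(ones(28,28)), single(255*ones(28,28))};

% Concatenate along the fourth dimension. The third (channel)
% dimension is created automatically as a singleton.
X = cat(4,data{:});
sz = size(X);

% Scale the pixel values to the range [0,1].
X = X/255;
```

The resulting array has size 28-by-28-by-1-by-3, which matches the 'SSCB' format: two spatial dimensions, one channel, and three observations in the batch.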
See Also
dlarray | dlgradient | dlfeval | sgdmupdate | dlconv | batchnorm | relu | fullyconnect | softmax | minibatchqueue | onehotdecode
Topics
- Define Custom Training Loops, Loss Functions, and Networks
- Define Model Loss Function for Custom Training Loop
- Train Network Using Model Function
- Update Batch Normalization Statistics Using Model Function
- Initialize Learnable Parameters for Model Function
- Specify Training Options in Custom Training Loop
- List of Functions with dlarray Support