
Prune Filters in a Detection Network Using Taylor Scores

This example shows how to reduce network size and increase inference speed by pruning convolutional filters in a you only look once (YOLO) v3 object detection network.

Filter pruning is a compression technique that uses a criterion to identify and remove the least important filters in a network, reducing the overall memory footprint of the network without a significant reduction in network accuracy. The pruning algorithm used in this example is gradient-based and uses first-order Taylor expansion [1][2] to evaluate the importance of convolutional filters in a network. This example also shows how to generate code for the pruned network and deploy a processor-in-the-loop (PIL) executable to a Raspberry Pi® embedded target.

This example uses a YOLO v3 detector trained on the Caltech Cars data set. For more information, see Object Detection Using YOLO v3 Deep Learning (Computer Vision Toolbox).

Load Network for Pruning

Load the trained network for pruning. The pretrained YOLO v3 detector in this example is based on SqueezeNet: it uses the SqueezeNet feature extraction network with two detection heads added at the end. The second detection head is twice the size of the first detection head, so it is better able to detect small objects.

For information on network training, see Object Detection Using YOLO v3 Deep Learning (Computer Vision Toolbox).

Download the yolov3SqueezeNetVehicleExample_21aSPKG.zip file containing the pretrained YOLO v3 network. This file is approximately 23 MB in size. Download the file from the MathWorks website, then unzip the file.

fileName = matlab.internal.examples.downloadSupportFile("vision/data/","yolov3SqueezeNetVehicleExample_21aSPKG.zip");
unzip(fileName);
matFile = "yolov3SqueezeNetVehicleExample_21aSPKG.mat";
pretrained = load(matFile);
yolov3Detector = pretrained.detector;
net = yolov3Detector.Network
net = 
  dlnetwork with properties:

         Layers: [75×1 nnet.cnn.layer.Layer]
    Connections: [84×2 table]
     Learnables: [66×3 table]
          State: [6×3 table]
     InputNames: {'data'}
    OutputNames: {'customOutputConv1'  'customOutputConv2'}
    Initialized: 1

  View summary with summary.

Load and Prepare Vehicle Data

Load the training and validation data that will be used for pruning, fine-tuning, and retraining. This example uses a small labeled data set that contains 295 images. Many of these images come from the Caltech Cars 1999 and 2001 data sets, created by Pietro Perona and used with permission. Each image contains one or two labeled instances of a vehicle.

Unzip the vehicle images and load the vehicle ground truth data.

unzip("vehicleDatasetImages.zip");
data = load("vehicleDatasetGroundTruth.mat");
vehicleDataset = data.vehicleDataset;

Add the full path to the local vehicle data folder.

vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

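Split the data set into a training set for training the network and a test set for evaluating the network. Use 60% of the data for the training set and the rest for the test set. Because the code does not set the random seed before calling randperm, the split, and therefore the results that follow, can differ between runs.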

shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));
trainingDataTbl = vehicleDataset(shuffledIndices(1:idx), :);
testDataTbl = vehicleDataset(shuffledIndices(idx+1:end), :);
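As a quick check on the split, the sketch below works through the arithmetic, assuming the data set contains exactly 295 images as stated above and using the mini-batch size of 16 set later in the example. It also shows that calling rng before randperm makes the shuffle repeatable; the example code itself does not set a seed before the split. The mini-batch count assumes minibatchqueue returns the final partial mini-batch, which is its default behavior.

```matlab
% Sketch: split sizes and mini-batches per pass, assuming 295 images.
numImages = 295;
numTrain = floor(0.6 * numImages);   % images in the training set
numTest = numImages - numTrain;      % images in the test set

% Mini-batch size used later in the example; the last mini-batch is partial.
miniBatchSize = 16;
numMiniBatchesPerEpoch = ceil(numTrain / miniBatchSize);

% Seeding the global generator makes randperm, and hence the split, repeatable.
rng(0);
perm1 = randperm(numImages);
rng(0);
perm2 = randperm(numImages);
isRepeatable = isequal(perm1, perm2);

fprintf('Training: %d, test: %d, mini-batches per epoch: %d\n', ...
    numTrain, numTest, numMiniBatchesPerEpoch);
```

With these numbers, one pass through the training data yields 12 mini-batches, the last of which contains a single image.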

Create image and box label datastores.

imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
imdsTest = imageDatastore(testDataTbl.imageFilename);
bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));
trainingData = combine(imdsTrain, bldsTrain);
testData = combine(imdsTest, bldsTest);

Use validateInputData to detect invalid images, bounding boxes or labels. Any invalid samples must either be discarded or fixed for proper training.

validateInputData(trainingData);
validateInputData(testData);

Use the transform function to apply custom data augmentations to the training data. The augmentData helper function, listed at the end of the example, applies the following augmentations to the input data.

  • Color jitter augmentation in HSV space

  • Random horizontal flip

  • Random scaling by 10 percent

augmentedTrainingData = transform(trainingData, @augmentData);
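The augmentData helper is not shown in this section. As an illustration only, the sketch below shows what a horizontal flip has to do to a bounding box: mirror the image columns and move the x coordinate of the box while keeping its width, height, and y coordinate. It assumes a box [x y w h] covers the integer pixel columns x through x+w-1; the image and box values are hypothetical.

```matlab
% Sketch: horizontally flip an image and one [x y w h] box.
% Assumed convention: the box covers integer pixel columns x .. x+w-1.
I = zeros(6, 10);     % hypothetical 6-by-10 grayscale image
box = [2 3 4 2];      % hypothetical box: x=2, y=3, w=4, h=2

% Mark the box pixels so the flip can be checked visually or numerically.
I(box(2):box(2)+box(4)-1, box(1):box(1)+box(3)-1) = 1;

W = size(I, 2);
Iflipped = I(:, end:-1:1);                           % mirror the columns
boxFlipped = [W - box(1) - box(3) + 2, box(2:4)];    % mirror x only

% Columns that contain marked pixels after the flip.
cols = find(any(Iflipped, 1));
fprintf('Flipped box: [%d %d %d %d]\n', boxFlipped);
```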

Use transform to preprocess the training data for computing the anchor boxes, because the training images used in this example are bigger than 227-by-227 and vary in size. Then, use the estimateAnchorBoxes function to estimate the anchor boxes. Specify the number of anchors as 6 to achieve a good tradeoff between the number of anchors and mean IoU. To prevent the estimated anchor boxes from changing while tuning other hyperparameters, set the random seed prior to estimation using rng.

networkInputSize = [227 227 3];
trainingDataForEstimation = transform(trainingData, @(data)preprocessData(data, networkInputSize));
numAnchors = 6;
rng("default")
[anchorBoxes, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors);

Specify anchorBoxes to use in both detection heads. Select the anchor boxes for each detection head based on the feature map size: use larger anchors at lower scale and smaller anchors at higher scale. To do so, sort the anchors by area in descending order, and assign the first three to the first detection head and the next three to the second detection head.

area = anchorBoxes(:, 1).*anchorBoxes(:, 2);
[~, idx] = sort(area, 'descend');
anchorBoxes = anchorBoxes(idx, :);
anchorBoxMasks = {[1,2,3] [4,5,6]};
classNames = trainingDataTbl.Properties.VariableNames(2:end);
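The sort-and-split logic above can be tried on small made-up anchors. In this sketch the anchor values are hypothetical [height width] pairs; only the ordering by area and the assignment through the masks mirror the example code.

```matlab
% Sketch: sort hypothetical [height width] anchors by area and split them
% between two detection heads, as the example code does.
anchors = [20 30; 90 120; 45 50; 150 160; 10 12; 60 80];   % made-up values

areas = anchors(:,1) .* anchors(:,2);
[~, order] = sort(areas, 'descend');
anchors = anchors(order, :);

masks = {[1 2 3], [4 5 6]};
head1Anchors = anchors(masks{1}, :);   % three largest anchors: first head
head2Anchors = anchors(masks{2}, :);   % three smallest anchors: second head
```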

Preprocess the augmented training data to prepare for training. The preprocessData helper function (defined at the end of this example) resizes the images to the network input size while maintaining the aspect ratio and scales the image pixels to the range [0 1].

augimdsTrain = transform(augmentedTrainingData, @(data)preprocessData(data, networkInputSize));
augimdsTest = transform(testData, @(data)preprocessData(data, networkInputSize));

Evaluate Detector Network Before Pruning

Use the evaluateDetectionPrecision function to measure the average precision of the trained network before pruning. The average precision provides a single number that incorporates the ability of the detector to make correct classifications (precision) and the ability of the detector to find all relevant objects (recall).

results = detect(yolov3Detector,testData,MiniBatchSize=16);
[apTrainedNet, recallTrainedNet, precisionTrainedNet] = evaluateDetectionPrecision(results,testData);
accuracyTrainedNet = mean(apTrainedNet)*100
accuracyTrainedNet = 79.9338
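Average precision condenses a precision-recall curve into one number. The sketch below uses one common definition, all-point interpolated AP (precision made monotonically non-increasing, then integrated over recall), on hypothetical precision and recall values. It illustrates the idea only and is not a statement of exactly how evaluateDetectionPrecision computes AP.

```matlab
% Sketch: all-point interpolated average precision from a PR curve.
% The recall/precision values are hypothetical.
recall    = [0.1 0.2 0.4 0.4 0.6 0.8];
precision = [1.0 1.0 0.8 0.9 0.7 0.5];

% Pad the curve so that it starts at recall 0 and ends at recall 1.
r = [0, recall, 1];
p = [0, precision, 0];

% Make precision monotonically non-increasing by scanning from the right.
for k = numel(p)-1:-1:1
    p(k) = max(p(k), p(k+1));
end

% Sum the rectangle areas at the points where recall changes.
changed = find(r(2:end) ~= r(1:end-1)) + 1;
ap = sum((r(changed) - r(changed-1)) .* p(changed));
fprintf('AP = %.4f\n', ap);
```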

The precision-recall (PR) curve shows how precise a detector is at varying levels of recall. Ideally, the precision is 1 at all recall levels.

figure
plot(recallTrainedNet,precisionTrainedNet)
xlabel("Recall")
ylabel("Precision")
grid on
title("Average Precision = " + apTrainedNet)

Figure: Precision-recall curve of the original network (title "Average Precision = 0.79934", x-axis Recall, y-axis Precision).

Prune Network

Create a prunable object based on first-order Taylor approximation by using taylorPrunableNetwork (Deep Learning Toolbox). A TaylorPrunableNetwork object has properties and methods similar to those of a dlnetwork, plus pruning-specific properties and methods, so you can substitute it for a dlnetwork in a custom training loop. Pruning is iterative: each time the loop runs, until a stopping criterion is met, the updatePrunables function removes a small number of the least important convolution filters and updates the network architecture.

prunableNet = taylorPrunableNetwork(net)
prunableNet = 
  TaylorPrunableNetwork with properties:

      Learnables: [66×3 table]
           State: [6×3 table]
      InputNames: {'data'}
     OutputNames: {'customOutputConv1'  'customOutputConv2'}
    NumPrunables: 3496

maxPrunableFilters = prunableNet.NumPrunables;
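To make the first-order Taylor criterion concrete, the sketch below scores the filters of one convolution layer using one common form of it: for each example, take the absolute value of the spatial mean of activation times gradient, then average over the mini-batch and accumulate over mini-batches. The activation and gradient arrays are randomly generated stand-ins, and the per-layer score normalization used by some formulations is omitted. This is not the internal implementation of taylorPrunableNetwork or updateScore.

```matlab
% Sketch: first-order Taylor importance scores for one convolution layer.
% A (activations) and G (loss gradients w.r.t. A) are hypothetical
% H-by-W-by-C-by-N arrays.
rng(1);
H = 4; W = 4; C = 3; N = 2;
A = rand(H, W, C, N);
G = randn(H, W, C, N) * 1e-3;
G(:, :, 2, :) = 0;                 % make channel 2 irrelevant to the loss

% Per example and channel: |mean over spatial positions of A .* G|.
perExample = abs(mean(mean(A .* G, 1), 2));   % 1-by-1-by-C-by-N

% Average over the examples in the mini-batch.
scores = reshape(mean(perExample, 4), C, 1);  % C-by-1

% Accumulate across mini-batches (the same batch twice, for illustration).
accumulated = zeros(C, 1);
for batch = 1:2
    accumulated = accumulated + scores;
end

% The lowest accumulated score marks the least important filter.
[~, leastImportant] = min(accumulated);
fprintf('Least important filter: %d\n', leastImportant);
```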

Specify Pruning Options

Set the pruning options.

  • maxPruningIterations defines the maximum number of iterations to be used in the pruning loop.

  • maxToPrune is the maximum number of filters to be pruned in each iteration of the pruning loop.

  • validationFrequency is the number of pruning iterations between validations of the pruned network on the test data.

maxPruningIterations = 20;
maxToPrune = 64;
validationFrequency = 5;
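The pruning options bound how much the loop can remove: with maxPruningIterations = 20 and maxToPrune = 64, at most 20*64 = 1280 of the 3496 prunable filters can be removed. The sketch below also shows one way to pick the filters to remove in a single step, ranking hypothetical accumulated scores across all layers and taking the lowest maxToPrune. Whether updatePrunables ranks filters globally or per layer is not stated here; the global ranking is an assumption of the sketch.

```matlab
% Upper bound on filters removed by the whole pruning loop above.
totalBudget = 20 * 64;   % maxPruningIterations * maxToPrune

% Sketch: choose at most maxToPrune filters with the lowest scores,
% ranked globally across layers (an assumption). Values are hypothetical.
scores = [0.9 0.05 0.4 0.01 0.7 0.02 0.3];   % one score per filter
layerOfFilter = [1 1 1 2 2 3 3];             % layer that owns each filter
numLayers = 3;
maxToPrune = 3;

[~, order] = sort(scores, 'ascend');
toPrune = order(1:min(maxToPrune, numel(order)));
keepMask = true(size(scores));
keepMask(toPrune) = false;

% Number of filters remaining in each layer after this pruning step.
remainingPerLayer = accumarray(layerOfFilter(keepMask)', 1, [numLayers 1])';
```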

Set the fine-tuning options.

  • Fine-tune the network with a custom training loop for up to 40 mini-batches in every pruning iteration. The loop also stops when the mini-batch queue runs out of data, so each pruning iteration uses at most one pass through the training data.

  • Specify the options for SGDM optimization: a learn rate of 0.00001, a momentum of 0.9, and an L2 regularization factor of 0.0005. Initialize the velocity to []; SGDM uses it to store the velocity of the gradients.

  • Specify the penalty threshold as 0.5. Detections whose overlap (IoU) with the ground truth is less than 0.5 are penalized.

  • Specify a mini-batch size of 16 to fine-tune the network.

numMinibatchUpdates = 40;
learnRate = 1e-5;
momentum = 0.9;
l2Regularization = 0.0005;
penaltyThreshold = 0.5;
miniBatchSize = 16;
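The penalty threshold is compared against the overlap between a prediction and a ground truth box, typically measured as intersection over union (IoU). The sketch below computes IoU for hypothetical [x y w h] boxes, assuming continuous coordinates where a box spans x to x+w. The generateTargets helper used by the example is not shown in this section, so this illustrates the measure rather than that helper's code.

```matlab
% Sketch: IoU of two [x y w h] boxes (continuous coordinates assumed).
overlap1D = @(lo1, len1, lo2, len2) ...
    max(0, min(lo1 + len1, lo2 + len2) - max(lo1, lo2));
interArea = @(a, b) overlap1D(a(1), a(3), b(1), b(3)) * ...
                    overlap1D(a(2), a(4), b(2), b(4));
iou = @(a, b) interArea(a, b) / (a(3)*a(4) + b(3)*b(4) - interArea(a, b));

penaltyThreshold = 0.5;
groundTruth = [10 10 40 20];   % hypothetical ground truth box

% Prediction shifted right by 10 pixels: large overlap.
overlapNear = iou(groundTruth, [20 10 40 20]);
% Prediction shifted right by 30 pixels: small overlap.
overlapFar = iou(groundTruth, [40 10 40 20]);

isBelowNear = overlapNear < penaltyThreshold;
isBelowFar = overlapFar < penaltyThreshold;
```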

Create the minibatchqueue

Use a minibatchqueue object to process and manage the mini-batches of images. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) which returns the batched images and bounding boxes combined with the respective class IDs.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). Do not add a format to the bounding boxes.

  • Cast the bounding boxes to double.

mbq = minibatchqueue(augimdsTrain, 2,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(images, boxes, labels) preprocessMiniBatch(images, boxes, labels, classNames), ...
    MiniBatchFormat=["SSCB", ""],...
    OutputCast=["", "double"]);
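Because images contain different numbers of vehicles, the preprocessMiniBatch helper pads the box matrix of each image with zero rows before stacking the matrices along the fourth dimension. The sketch below reproduces that padding step with plain concatenation instead of padarray, on hypothetical [x y w h classID] rows for two images.

```matlab
% Sketch: pad per-image [x y w h classID] matrices to a common number of
% rows and stack them along dimension 4. Values are hypothetical.
responses = {[10 20 30 40 1]; ...                 % image 1: one box
             [5 5 10 10 1; 50 60 20 20 1]};       % image 2: two boxes

% Largest number of boxes in any image of the mini-batch.
len = max(cellfun(@(r) size(r, 1), responses));

% Append zero rows so that every matrix has len rows.
padded = cellfun(@(r) [r; zeros(len - size(r, 1), size(r, 2))], ...
    responses, 'UniformOutput', false);

T = cat(4, padded{:});   % len-by-5-by-1-by-numImages
```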

Prune Network Using Custom Pruning Loop

Initialize the training progress plots.

figure("Position",[10,10,700,700])
tl = tiledlayout(3,1);
lossAx = nexttile;
lineLossFinetune = animatedline(Color=[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Fine-Tuning Iteration")
ylabel("Loss")
grid on
title("Mini-Batch Loss during Pruning")
xTickPos = [];

accuracyAx = nexttile;
lineAccuracyPruning = animatedline(Color=[0.098 0.325 0.85]);
ylim([50 100])
xlabel("Pruning Iteration")
ylabel("Accuracy")
grid on
addpoints(lineAccuracyPruning, 0, accuracyTrainedNet)
title("Validation Accuracy After Pruning")

numPrunablesAx = nexttile;
lineNumPrunables = animatedline(Color=[0.4660 0.6470 0.1880]);
ylim([200 3600])
xlabel("Pruning Iteration")
ylabel("Prunable Filters")
grid on
addpoints(lineNumPrunables, 0, double(maxPrunableFilters))
title("Number of Prunable Convolution Filters After Pruning")

Prune the network. For each mini-batch in the pruning iteration, perform the following steps:

  • Evaluate the pruning activations, gradients of the pruning activations, model gradients, state, and loss by using the dlfeval and modelLossPruning functions.

  • Update the network state.

  • Apply a weight decay factor (L2 regularization) to the gradients for more robust training.

  • Update the network learnable parameters using the stochastic gradient descent with momentum (SGDM) algorithm.

  • Compute first-order Taylor scores and accumulate the scores across previous mini-batches of data.

  • Display the progress.
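The regularization and update steps above can be written out for a single parameter vector. This sketch assumes the velocity-based form of SGD with momentum (velocity = momentum*velocity - learnRate*gradient, then weights = weights + velocity), a standard formulation of the kind of update sgdmupdate applies to every learnable. The weights and gradients are hypothetical, and the loss gradient is held fixed across the two steps for simplicity.

```matlab
% Sketch: L2 weight decay followed by an SGDM step (velocity form).
w = [0.5; -1.0];            % hypothetical weights
g = [0.2; 0.1];             % hypothetical loss gradients (held fixed)
velocity = zeros(size(w));  % plays the role of velocity = [] initially
learnRate = 1e-5;
momentum = 0.9;
l2Regularization = 0.0005;

for step = 1:2
    gReg = g + l2Regularization * w;                     % add weight decay
    velocity = momentum * velocity - learnRate * gReg;   % momentum update
    w = w + velocity;                                    % apply the step
end
```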

In a loop, alternate between fine-tuning and pruning.

start = tic;
iteration = 0;

for pruningIteration = 1:maxPruningIterations

    % Shuffle the data in the minibatchqueue.
    shuffle(mbq);

    % Reset the velocity parameter for the SGDM solver in every pruning
    % iteration.
    velocity = [];

    % Loop over mini-batches.
    fineTuningIteration = 0;
    while hasdata(mbq)
        iteration = iteration + 1;
        fineTuningIteration = fineTuningIteration + 1;

        % Read mini-batch of data.
        [X, T] = next(mbq);

        % Evaluate the pruning activations, gradients of the pruning
        % activations, model gradients, state, and loss using dlfeval and
        % modelLossPruning functions.
        [loss, pruningGradients, netGradients, pruningActivations, state] = ...
            dlfeval(@modelLossPruning, prunableNet, X, T, anchorBoxes, ...
            anchorBoxMasks, penaltyThreshold);

        % Update the network state.
        prunableNet.State = state;

        % Apply L2 regularization.
        netGradients = dlupdate(@(g,w) g + l2Regularization*w, ...
            netGradients, prunableNet.Learnables);

        % Update the network parameters using the SGDM optimizer.
        [prunableNet, velocity] = sgdmupdate(prunableNet, ...
            netGradients, velocity, learnRate, momentum);

        % Compute first-order Taylor scores and accumulate the score across
        % previous mini-batches of data.
        prunableNet = updateScore(prunableNet, pruningActivations, pruningGradients);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossFinetune, iteration, double(loss.totalLoss))
        title(tl,"Processing Pruning Iteration: " + pruningIteration + " of " + maxPruningIterations + ...
            ", Elapsed Time: " + string(D))
        % Synchronize the x-axis of the accuracy plot with the loss plot.
        xlim(accuracyAx,lossAx.XLim)
        xlim(numPrunablesAx,lossAx.XLim)
        drawnow
        
        % Stop the fine-tuning loop when numMinibatchUpdates is reached.
        if fineTuningIteration >= numMinibatchUpdates
            break
        end

    end

    % Prune filters based on previously computed Taylor scores.
    prunableNet = updatePrunables(prunableNet, MaxToPrune = maxToPrune);

    % Show results on validation data set in a subset of pruning
    % iterations.
    isLastPruningIteration = pruningIteration == maxPruningIterations;
    if (mod(pruningIteration, validationFrequency) == 0 || isLastPruningIteration)
        [ap,~,~] = modelAccuracy(prunableNet, augimdsTest, anchorBoxes, anchorBoxMasks, classNames, 16);
        accuracy = mean(ap)*100;
        addpoints(lineAccuracyPruning, iteration, accuracy)
        addpoints(lineNumPrunables,iteration,double(prunableNet.NumPrunables))
    end

    % Set x-axis tick values at the end of each pruning iteration.
    xTickPos = [xTickPos, iteration]; %#ok<AGROW>
    xticks(lossAx,xTickPos)
    xticks(accuracyAx,[0,xTickPos])
    xticks(numPrunablesAx,[0,xTickPos])
    xticklabels(accuracyAx,["Unpruned",string(1:pruningIteration)])
    xticklabels(numPrunablesAx,["Unpruned",string(1:pruningIteration)])
    drawnow
end

Figure: Pruning progress in three panels: Mini-Batch Loss during Pruning (Loss vs. Fine-Tuning Iteration), Validation Accuracy After Pruning (Accuracy vs. Pruning Iteration), and Number of Prunable Convolution Filters After Pruning (Prunable Filters vs. Pruning Iteration).

During each pruning iteration, the validation accuracy can decrease because pruning convolutional filters changes the network structure. To recover the lost accuracy, retrain the network after pruning.

Once pruning is complete, convert the deep.prune.TaylorPrunableNetwork object back to a dlnetwork for retraining and further analysis.

prunedNet = dlnetwork(prunableNet);
save("prunedNet","prunedNet");

Retrain Pruned Network

The pruning process can cause the prediction accuracy to decrease. Try to improve the prediction accuracy by retraining the network using a custom training loop.

Specify Training Options

Specify the options to use during retraining.

  • Specify the options for SGDM optimization: a learn rate of 0.00001 (the learnRate value set earlier), a momentum of 0.9, and an L2 regularization factor of 0.0005. Initialize the velocity to []; SGDM uses it to store the velocity of the gradients. Train for 10 epochs.

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) which returns the batched images and bounding boxes combined with the respective class IDs.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). Do not add a format to the bounding boxes.

  • Cast the bounding boxes to double.

velocity = [];
momentum = 0.9;
numEpochs = 10;
l2Regularization = 0.0005;
mbq = minibatchqueue(augimdsTrain, 2,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(images, boxes, labels) preprocessMiniBatch(images, boxes, labels, classNames), ...
    MiniBatchFormat=["SSCB", ""],...
    OutputCast=["", "double"]);

Train Network Using Custom Training Loop

Initialize the training progress plot.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

For each epoch, loop over mini-batches while data is still available in the minibatchqueue. Update the network parameters using the SGDM algorithm.

iteration = 0;
start = tic;
prunedDetectorNet = prunedNet;

for i = 1:numEpochs
    % Shuffle the data in the minibatchqueue.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X, T] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLossTraining function.
        [loss, gradients, state] = dlfeval(@modelLossTraining, prunedDetectorNet,...
            X, T, anchorBoxes, anchorBoxMasks, penaltyThreshold);

        % Update the network state.
        prunedDetectorNet.State = state;

        % Apply L2 regularization.
        gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, prunedDetectorNet.Learnables);

        % Update the network parameters using the SGDM optimizer.
        [prunedDetectorNet, velocity] = sgdmupdate(prunedDetectorNet, gradients, velocity, learnRate, momentum);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(loss.totalLoss))
        title("Retraining After Pruning" + newline + "Epoch: " + i + ", Elapsed: " + string(D))
        drawnow
    end
end

Figure: Retraining loss (title "Retraining After Pruning, Epoch: 10, Elapsed: 00:02:30", x-axis Iteration, y-axis Loss).

prunedyolov3ObjectDetector = yolov3ObjectDetector(prunedDetectorNet,classNames,yolov3Detector.AnchorBoxes);
save("prunedyolov3","prunedyolov3ObjectDetector");

Compare Original Network and Pruned Network

Determine the impact of pruning on each layer.

originalNetFilters = numConvLayerFilters(net);
prunedNetFilters = numConvLayerFilters(prunedDetectorNet);
convFilters = join(originalNetFilters,prunedNetFilters,Keys="Row");

Visualize the number of filters in the original network and in the pruned network.

figure("Position",[10,10,900,900])
bar([convFilters.(1),convFilters.(2)])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters Per Layer")
xticks(1:(numel(convFilters.Row)))
xticklabels(convFilters.Row)
xtickangle(90)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters","Location","southoutside")

Figure: Number of Filters Per Layer, a bar chart of Number of Filters vs. Layer comparing Original Network Filters and Pruned Network Filters.

Next, compare the average precision of the original network and the pruned network.

[apPrunedNet,recallPrunedNet,precisionPrunedNet] = modelAccuracy(prunedDetectorNet, augimdsTest, anchorBoxes, anchorBoxMasks, classNames, 16);
accuracyPrunedNet = mean(apPrunedNet)*100
accuracyPrunedNet = 83.3012

The precision-recall (PR) curve is a good way to evaluate the performance of the object detector. Ideally, the precision is 1 for all levels of recall. In this run, the retrained pruned detector has a slightly higher average precision than the original detector (83.30 versus 79.93), and its precision stays high as the recall increases. Exact values can vary between runs because the train/test split is random.

figure
plot(recallTrainedNet,precisionTrainedNet,recallPrunedNet,precisionPrunedNet)
xlabel("Recall")
ylabel("Precision")
grid on
title("Precision Comparison of Original and Pruned Network")
legend("Original Network","Pruned Network");

Figure: Precision Comparison of Original and Pruned Network (Precision vs. Recall for the Original Network and the Pruned Network).

Next, estimate the number of learnable parameters and the approximate memory of the original and pruned networks to understand the impact of pruning on network size.

analyzeNetworkMetrics(net,prunedDetectorNet,accuracyTrainedNet,accuracyPrunedNet)
ans=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    Accuracy
                         __________________    ___________________________    ________

    Original Network         6.4158e+06                   24.475               79.934 
    Pruned Network           1.8566e+06                   7.0824               83.301 
    Percentage Change           -71.062                  -71.062               4.2128 
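The Percentage Change row follows the formula used in analyzeNetworkMetrics, 100*(pruned - original)/original. The sketch below reproduces it from the rounded values in the table and also makes a rough memory estimate assuming 4 bytes (single precision) per learnable parameter. That estimate is close to the memory column of the table, but estimateNetworkMetrics may account for details this sketch ignores.

```matlab
% Sketch: percentage change and a rough memory estimate from the
% rounded table values (original network first, pruned network second).
learnables = [6.4158e6; 1.8566e6];
accuracy = [79.934; 83.301];

% Rough memory in MB, assuming 4 bytes per single-precision parameter.
memoryMB = learnables * 4 / 2^20;

% Relative change of the pruned network with respect to the original.
pctLearnables = 100 * (learnables(2) - learnables(1)) / learnables(1);
pctAccuracy = 100 * (accuracy(2) - accuracy(1)) / accuracy(1);

fprintf('Memory (MB): %.4f, %.4f\n', memoryMB);
fprintf('Change: learnables %.3f%%, accuracy %.3f%%\n', ...
    pctLearnables, pctAccuracy);
```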

Deploy Pruned YOLOv3 Network to Raspberry Pi

Optionally, you can use MATLAB® Coder™ to generate C++ code for the pruned network that takes advantage of the ARM® Compute Library. You can integrate the generated code into your project as source code, static or dynamic libraries, or an executable that you can deploy to a variety of ARM CPU platforms such as Raspberry Pi. This example uses the PIL-based workflow to generate a MEX function that, when called from MATLAB, runs the executable generated on the Raspberry Pi.

Third-Party Prerequisites

PIL MEX Function

In this example, you generate code for the entry-point function yolov3Raspi. This function uses the coder.loadDeepLearningNetwork function to load a deep learning model and to construct and set up a CNN class. Then the entry-point function detects vehicles in the input and returns an output image displaying the detections.

type yolov3Raspi.m
function outImg = yolov3Raspi(in,matFile)

%   Copyright 2022 The MathWorks, Inc.

persistent yolov3Obj;

if isempty(yolov3Obj)
    yolov3Obj = coder.loadDeepLearningNetwork(matFile);
end

% Call to detect method.
[bboxes,~,labels] = detect(yolov3Obj,in,'Threshold',0.5);

% Convert categorical labels to cell array of character vectors.
labels = cellstr(labels);

% Annotate detections in the image.
outImg = insertObjectAnnotation(in,'rectangle',bboxes,labels);

To generate a PIL MEX function, create a code configuration object for a static library and set the verification mode to 'PIL'. Set the target language to C++.

cfg = coder.config("lib",ecoder=true);
cfg.VerificationMode = "PIL";
cfg.TargetLang = "C++";

Create a deep learning configuration object for the ARM Compute Library. Specify the library version and ARM architecture. For this example, assume that the ARM Compute Library on the Raspberry Pi hardware is version 20.02.1.

dlcfg = coder.DeepLearningConfig("arm-compute");
dlcfg.ArmComputeVersion = "20.02.1";
dlcfg.ArmArchitecture = "armv7";

Set the DeepLearningConfig property of cfg to dlcfg.

cfg.DeepLearningConfig = dlcfg;

Use the MATLAB Support Package for Raspberry Pi function, raspi, to create a connection to the Raspberry Pi. In the following code, replace:

  • raspiname with the name of your Raspberry Pi

  • username with your user name

  • password with your password

r = raspi("raspiname","username","password");

Then, create a coder.Hardware object for Raspberry Pi and attach it to the code generation configuration object.

hw = coder.hardware("Raspberry Pi");
cfg.Hardware = hw;

Generate a PIL MEX function for the original network in yolov3SqueezeNetVehicleExample_21aSPKG.mat by using the codegen command.

codegen -config cfg yolov3Raspi -args {ones(227,227,3,'single'),coder.Constant("yolov3SqueezeNetVehicleExample_21aSPKG.mat")}

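Read a sample image and call the generated PIL MEX function yolov3Raspi_pil. The PIL MEX function launches the yolov3Raspi.elf executable on the Raspberry Pi and returns the results of the execution to MATLAB. Because the PIL MEX function is cleared after each run, each timed call also includes loading the network on the target and PIL communication overhead, so treat the measured times as a rough comparison.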

data = read(augimdsTest);
I = data{1};
tic;
detectedImage = yolov3Raspi_pil(I,"yolov3SqueezeNetVehicleExample_21aSPKG.mat");
execTimeOriginalNet = toc;
clear yolov3Raspi_pil;
imshow(detectedImage);
title("Execution Time of Original Network = "+execTimeOriginalNet+"s");
saveas(gcf,"DetectionResultsOriginalNet.png");
close(gcf);
imshow("DetectionResultsOriginalNet.png");

Figure: Detection results of the original network on the Raspberry Pi, titled with its execution time.

Then, generate a PIL MEX function for the pruned network in prunedyolov3.mat by using the codegen command.

codegen -config cfg yolov3Raspi -args {ones(227,227,3,'single'),coder.Constant("prunedyolov3.mat")}

Run the generated PIL MEX.

tic;
detectedImage = yolov3Raspi_pil(I,"prunedyolov3.mat");
execTimePrunedNet = toc;
clear yolov3Raspi_pil
imshow(detectedImage);
title("Execution Time of Pruned Network = "+execTimePrunedNet+"s");
saveas(gcf,"DetectionResultsPrunedNet.png");
close(gcf);
imshow("DetectionResultsPrunedNet.png");

Figure: Detection results of the pruned network on the Raspberry Pi, titled with its execution time.

Helper Functions

Model Gradients Function for Fine-Tuning and Pruning

The modelLossPruning function takes as input a deep.prune.TaylorPrunableNetwork object prunableNet, a mini-batch of input data X with the corresponding ground truth boxes T, the anchor boxes, the anchor box masks, and the penalty threshold. It returns the loss, the gradients of the loss with respect to the pruning activations, the gradients of the loss with respect to the learnable parameters in prunableNet, the pruning activations, and the network state.

function [loss, pruningGradients, netGradients, pruningActivations, state] = modelLossPruning(prunableNet, X, T, anchors, mask, penaltyThreshold)

inputImageSize = size(X,1:2);

% Gather the ground truths for post processing.
YTrain = gather(extractdata(T));

% Extract the predictions from the network.
[YPredCell, state, pruningActivations] = yolov3ForwardGate(prunableNet, X, mask);

% Gather the activations for post processing and extract dlarray data.
gatheredPredictions = cellfun(@gather, YPredCell(:,1:6),'UniformOutput',false);
gatheredPredictions = cellfun(@extractdata, gatheredPredictions, 'UniformOutput', false);

% Convert predictions from grid cell coordinates to box coordinates.
tiledAnchors = generateTiledAnchors(gatheredPredictions(:,2:5),anchors,mask);
gatheredPredictions(:,2:5) = applyAnchorBoxOffsets(tiledAnchors, gatheredPredictions(:,2:5), inputImageSize);

% Generate target for predictions from the ground truth data.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions, YTrain, inputImageSize, anchors, mask, penaltyThreshold);

% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;

loss.boxLoss = boxLoss;
loss.objLoss = objLoss;
loss.clsLoss = clsLoss;
loss.totalLoss = totalLoss;

% Differentiate loss w.r.t learnables and activations
[netGradients, pruningGradients] = dlgradient(totalLoss, prunableNet.Learnables, pruningActivations);

end

Model Gradients Function for Retraining

The modelLossTraining function takes as input a dlnetwork object net, a mini-batch of input data X with the corresponding ground truth boxes T, the anchor boxes, the anchor box masks, and the penalty threshold. It returns the loss, the gradients of the loss with respect to the learnable parameters in net, and the network state.

function [loss, gradients, state] = modelLossTraining(net, X, T, anchors, mask, penaltyThreshold)

inputImageSize = size(X,1:2);

% Gather the ground truths for post processing.
YTrain = gather(extractdata(T));

% Extract the predictions from the network.
[YPredCell, state] = yolov3Forward(net,X,mask);

% Gather the activations for post processing and extract dlarray data.
gatheredPredictions = cellfun(@gather, YPredCell(:,1:6),'UniformOutput',false);
gatheredPredictions = cellfun(@extractdata, gatheredPredictions, 'UniformOutput', false);

% Convert predictions from grid cell coordinates to box coordinates.
tiledAnchors = generateTiledAnchors(gatheredPredictions(:,2:5),anchors,mask);
gatheredPredictions(:,2:5) = applyAnchorBoxOffsets(tiledAnchors, gatheredPredictions(:,2:5), inputImageSize);

% Generate target for predictions from the ground truth data.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions, YTrain, inputImageSize, anchors, mask, penaltyThreshold);

% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;

loss.boxLoss = boxLoss;
loss.objLoss = objLoss;
loss.clsLoss = clsLoss;
loss.totalLoss = totalLoss;

% Differentiate loss w.r.t learnables
gradients = dlgradient(totalLoss, net.Learnables);

end

Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of data and returns the batched images and bounding boxes combined with the respective class IDs.

function [X, T] = preprocessMiniBatch(data, groundTruthBoxes, groundTruthClasses, classNames)
% Returns the images concatenated along the batch dimension in X and the
% bounding boxes concatenated with class IDs, zero-padded to a common
% number of rows, in T.

% Concatenate images along the batch dimension.
X = cat(4, data{:,1});

% Get class IDs from the class names.
classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
[~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);

% Append the class indices to the bounding boxes and create a single cell
% array of responses.
combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
len = max( cellfun(@(x)size(x,1), combinedResponses ) );
paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
T = cat(4, paddedBBoxes{:,1});
end

Evaluate Model Accuracy

The modelAccuracy function computes the average precision of the network on a data set.

function [ap, recall, precision] = modelAccuracy(net, augimds, anchorBoxes, anchorBoxMasks, classNames, miniBatchSize)
% MODELACCURACY computes the model accuracy on the data set augimds.
% Create a table to hold the bounding boxes, scores, and labels returned by
% the detector.
results = table('Size', [0 3], ...
    'VariableTypes', {'cell','cell','cell'}, ...
    'VariableNames', {'Boxes','Scores','Labels'});
mbqTest = minibatchqueue(augimds, 1, ...
    "MiniBatchSize", miniBatchSize, ...
    "MiniBatchFormat", "SSCB");

% Run detector on images in the test set and collect results.
while hasdata(mbqTest)
    % Read the datastore and get the image.
    XTest = next(mbqTest);

    % Run the detector.
    [bboxes, scores, labels] = yolov3Detect(net, XTest, net.OutputNames', anchorBoxes, anchorBoxMasks, 0.5, 0.5, classNames);

    % Collect the results.
    tbl = table(bboxes, scores, labels, 'VariableNames', {'Boxes','Scores','Labels'});
    results = [results; tbl];%#ok<AGROW>
end

% Evaluate the object detector using the average precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(results, augimds);
end

Evaluate Number of Filters in Convolution Layers

The numConvLayerFilters function returns the number of filters in each convolution layer.

function convFilters = numConvLayerFilters(net)
numLayers = numel(net.Layers);
convNames = [];
numFilters = [];
% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
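        % Read the filter count from the fourth dimension of the weights.
        % size drops trailing singleton dimensions, so the last element of
        % size(Weights) is not reliable when a layer has a single filter.
        numFilters = [numFilters; size(net.Layers(cnt).Weights,4)];%#ok<AGROW>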
        convNames = [convNames; string(net.Layers(cnt).Name)];%#ok<AGROW>
    end
end
convFilters = table(numFilters,RowNames=convNames);
end
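Reading the filter count from a weight array depends on how MATLAB reports sizes: size drops trailing singleton dimensions, so for a layer with exactly one filter the last element of size(Weights) is the number of input channels rather than the number of filters. The short check below uses a hypothetical all-zero weight array to show the difference; querying the fourth dimension directly returns the filter count in every case.

```matlab
% Hypothetical convolution weights laid out as
% height-by-width-by-channels-by-filters: 3-by-3 kernels, 64 input
% channels, and a single filter.
W = zeros(3, 3, 64, 1);

sizeW = size(W);          % [3 3 64]: the trailing singleton is dropped
lastDim = sizeW(end);     % 64, the number of input channels
numFilters = size(W, 4);  % 1, the actual number of filters
```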

Evaluate Network Statistics of the Original and Pruned Networks

The analyzeNetworkMetrics function takes as input the original network, the pruned network, and the accuracy of each network. It returns a table with the number of learnables, the approximate parameter memory, and the accuracy on the test data for both networks, along with the percentage change in each statistic.

function [statistics] = analyzeNetworkMetrics(originalNet,prunedNet,accuracyOriginal,accuracyPruned)

originalNetMetrics = estimateNetworkMetrics(originalNet);
prunedNetMetrics = estimateNetworkMetrics(prunedNet);

% Accuracy of original network and pruned network
perChangeAccu = 100*(accuracyPruned - accuracyOriginal)/accuracyOriginal;
AccuracyForNetworks = [accuracyOriginal;accuracyPruned;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics.NumberOfLearnables);
prunedNetLearnables = sum(prunedNetMetrics.NumberOfLearnables);
percentageChangeLearnables = 100*(prunedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;prunedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics.("ParameterMemory (MB)"));
approxPrunedMemory = sum(prunedNetMetrics.("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxPrunedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [ approxOriginalMemory; approxPrunedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory,AccuracyForNetworks, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)","Accuracy"], ...
    'RowNames',{'Original Network','Pruned Network','Percentage Change'});

end
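The percentage changes reported by analyzeNetworkMetrics are relative changes with respect to the original network. The sketch below applies the same formula to hypothetical numbers. For the Accuracy column, a relative change of about -5% here corresponds to an absolute drop of 4 percentage points, so the two should not be confused when reading the summary table.

```matlab
% Hypothetical values; in the example they come from
% estimateNetworkMetrics and modelAccuracy.
originalLearnables = 2.0e6;
prunedLearnables   = 1.2e6;
accuracyOriginal   = 0.80;
accuracyPruned     = 0.76;

% Relative change, using the same formula as analyzeNetworkMetrics.
pctChangeLearnables = 100*(prunedLearnables - originalLearnables)/originalLearnables;  % -40
pctChangeAccuracy   = 100*(accuracyPruned - accuracyOriginal)/accuracyOriginal;        % about -5

% For comparison: the absolute change in accuracy, in percentage points.
pointsChangeAccuracy = 100*(accuracyPruned - accuracyOriginal);                        % about -4
```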

Augmentation and Data Processing Functions

function data = augmentData(A)
% Apply random horizontal flipping and random scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter the color of RGB images.
data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            'Contrast',0.0,...
            'Hue',0.1,...
            'Saturation',0.2,...
            'Brightness',0.2);
    end

    % Randomly flip and scale the image.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
    I = imwarp(I,tform,'OutputView',rout);

    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,'OverlapThreshold',0.25);
    bboxes = round(bboxes);
    labels = labels(indices);

    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I, bboxes, labels};
    end
end
end

function data = preprocessData(data, targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.
for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);

    % Convert an input image with single channel to 3 channels.
    if numel(imgSize) < 3
        I = repmat(I,1,1,3);
    end
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);

    data(ii, 1:2) = {I, bboxes};
end
end

Utility Functions

function YPredCell = applyActivations(YPredCell)
% Apply activation functions on YOLOv3 outputs.
YPredCell(:,1:3) = cellfun(@ sigmoid, YPredCell(:,1:3), 'UniformOutput', false);
YPredCell(:,4:5) = cellfun(@ exp, YPredCell(:,4:5), 'UniformOutput', false);
YPredCell(:,6) = cellfun(@ sigmoid, YPredCell(:,6), 'UniformOutput', false);
end

function tiledAnchors = applyAnchorBoxOffsets(tiledAnchors,YPredCell,inputImageSize)
% Convert grid cell coordinates to box coordinates.
for i=1:size(YPredCell,1)
    [h,w,~,~] = size(YPredCell{i,1});
    tiledAnchors{i,1} = (tiledAnchors{i,1}+YPredCell{i,1})./w;
    tiledAnchors{i,2} = (tiledAnchors{i,2}+YPredCell{i,2})./h;
    tiledAnchors{i,3} = (tiledAnchors{i,3}.*YPredCell{i,3})./inputImageSize(2);
    tiledAnchors{i,4} = (tiledAnchors{i,4}.*YPredCell{i,4})./inputImageSize(1);
end
end
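Together, applyActivations and applyAnchorBoxOffsets decode raw network outputs into boxes normalized to the image size. The following sketch traces a single grid cell and anchor through both steps. The grid size, anchor, cell offsets, and raw outputs are hypothetical, and sigmoid is replaced by a local anonymous function so that the code runs on plain doubles without dlarray.

```matlab
% Local stand-in for sigmoid so the sketch works on ordinary doubles.
sigmoidFcn = @(z) 1 ./ (1 + exp(-z));

inputImageSize = [227 227 3];   % [height width channels], as in yolov3Detect
gridH = 14; gridW = 14;         % hypothetical output grid size
anchor = [40 60];               % hypothetical anchor, [height width] in pixels
row0 = 3; col0 = 5;             % zero-based cell offsets (see generateTiledAnchors)

% Hypothetical raw network outputs for this cell and anchor.
tx = 0; ty = 0; tw = 0; th = 0;

% applyActivations: sigmoid on x and y, exponential on width and height.
sx = sigmoidFcn(tx); sy = sigmoidFcn(ty);
ew = exp(tw);        eh = exp(th);

% applyAnchorBoxOffsets: center and size as fractions of the image.
bx = (col0 + sx)/gridW;
by = (row0 + sy)/gridH;
bw = (anchor(2)*ew)/inputImageSize(2);
bh = (anchor(1)*eh)/inputImageSize(1);
```

With all raw outputs at zero, the box is centered in its cell and has exactly the anchor's size.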

function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
% Mean squared error for bounding box position.
lossX = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,1),boxDeltaTarget(:,1),boxMaskTarget(:,1),boxErrorScaleTarget));
lossY = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,2),boxDeltaTarget(:,2),boxMaskTarget(:,1),boxErrorScaleTarget));
lossW = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,3),boxDeltaTarget(:,3),boxMaskTarget(:,1),boxErrorScaleTarget));
lossH = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,4),boxDeltaTarget(:,4),boxMaskTarget(:,1),boxErrorScaleTarget));
boxLoss = lossX+lossY+lossW+lossH;
end

function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
% Binary cross-entropy loss for class confidence score.
clsLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'ClassificationMode','multilabel'),classPredCell,classTarget,boxMaskTarget(:,3)));
end

function predictions = extractPredictions(YPredictions, anchorBoxMask)
% extractPredictions extracts and rearranges the prediction outputs of the
% YOLO v3 network into confidence, x, y, width, height, and class columns.

predictions = cell(size(YPredictions, 1),6);
for ii = 1:size(YPredictions, 1)
    % Get the required info on feature size.
    numChannelsPred = size(YPredictions{ii},3);
    numAnchors = size(anchorBoxMask{ii},2);
    numPredElemsPerAnchors = numChannelsPred/numAnchors;
    allIds = (1:numChannelsPred);

    stride = numPredElemsPerAnchors;
    endIdx = numChannelsPred;

    % X positions.
    startIdx = 1;
    predictions{ii,2} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    xIds = startIdx:stride:endIdx;

    % Y positions.
    startIdx = 2;
    predictions{ii,3} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    yIds = startIdx:stride:endIdx;

    % Width.
    startIdx = 3;
    predictions{ii,4} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    wIds = startIdx:stride:endIdx;

    % Height.
    startIdx = 4;
    predictions{ii,5} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    hIds = startIdx:stride:endIdx;

    % Confidence scores.
    startIdx = 5;
    predictions{ii,1} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    confIds = startIdx:stride:endIdx;

    % Accumulate all the non-class indices.
    nonClassIds = [xIds yIds wIds hIds confIds];

    % Class probabilities: the remaining channels, which are not in
    % nonClassIds.
    classIdx = setdiff(allIds,nonClassIds);
    predictions{ii,6} = YPredictions{ii}(:,:,classIdx,:);
end
end
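extractPredictions assumes that each anchor occupies a contiguous group of channels: x, y, width, height, and confidence, followed by one channel per class. Under that assumption, the strided indexing above selects the following channels for a hypothetical detection head with three anchors and one object class.

```matlab
numAnchors = 3;                                 % hypothetical anchors per head
numClasses = 1;                                 % hypothetical: one object class
numChannelsPred = numAnchors*(5 + numClasses);  % 18 output channels
stride = numChannelsPred/numAnchors;            % 6 channels per anchor

xIds    = 1:stride:numChannelsPred;   % [1 7 13]
yIds    = 2:stride:numChannelsPred;   % [2 8 14]
wIds    = 3:stride:numChannelsPred;   % [3 9 15]
hIds    = 4:stride:numChannelsPred;   % [4 10 16]
confIds = 5:stride:numChannelsPred;   % [5 11 17]

% Everything else is a class channel.
classIds = setdiff(1:numChannelsPred, [xIds yIds wIds hIds confIds]);  % [6 12 18]
```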

function [boxDeltaTarget, objectnessTarget, classTarget, maskTarget, boxErrorScaleTarget] = generateTargets(YPredCellGathered, groundTruth, inputImageSize, anchorBoxes, anchorBoxMask, penaltyThreshold)
% generateTargets creates target array for every prediction element
% x, y, width, height, confidence scores and class probabilities.
boxDeltaTarget = cell(size(YPredCellGathered,1),4);
objectnessTarget = cell(size(YPredCellGathered,1),1);
classTarget = cell(size(YPredCellGathered,1),1);
maskTarget = cell(size(YPredCellGathered,1),3);
boxErrorScaleTarget = cell(size(YPredCellGathered,1),1);

% Normalize the ground truth boxes w.r.t image input size.
gtScale = [inputImageSize(2) inputImageSize(1) inputImageSize(2) inputImageSize(1)];
groundTruth(:,1:4,:,:) = groundTruth(:,1:4,:,:)./gtScale;

for numPred = 1:size(YPredCellGathered,1)
    
    % Select anchor boxes based on anchor box mask indices.
    anchors = anchorBoxes(anchorBoxMask{numPred},:);

    bx = YPredCellGathered{numPred,2};
    by = YPredCellGathered{numPred,3};
    bw = YPredCellGathered{numPred,4};
    bh = YPredCellGathered{numPred,5};
    predClasses = YPredCellGathered{numPred,6};
    
    gridSize = size(bx);
    if numel(gridSize)== 3
        gridSize(4) = 1;
    end
    numClasses = size(predClasses,3)/size(anchors,1);
    
    % Initialize the required variables.
    mask = single(zeros(size(bx)));
    confMask = single(ones(size(bx)));
    classMask = single(zeros(size(predClasses)));
    tx = single(zeros(size(bx)));
    ty = single(zeros(size(by)));
    tw = single(zeros(size(bw)));
    th = single(zeros(size(bh)));
    tconf = single(zeros(size(bx)));
    tclass = single(zeros(size(predClasses)));
    boxErrorScale = single(ones(size(bx)));
    
    % Get the IOU of predictions with groundtruth.
    iou = getMaxIOUPredictedWithGroundTruth(bx,by,bw,bh,groundTruth);
    
    % Do not penalize the predictions which have iou greater than penalty
    % threshold.
    confMask(iou > penaltyThreshold) = 0;
    
    for batch = 1:gridSize(4)
        truthBatch = groundTruth(:,1:5,:,batch);
        truthBatch = truthBatch(all(truthBatch,2),:);
        
        % Center the ground-truth boxes and anchor priors at the origin so
        % that their overlap depends only on width and height.
        gtPred = [0-truthBatch(:,3)/2,0-truthBatch(:,4)/2,truthBatch(:,3),truthBatch(:,4)];
        anchorPrior = [0-anchorBoxes(:,2)/(2*inputImageSize(2)),0-anchorBoxes(:,1)/(2*inputImageSize(1)),anchorBoxes(:,2)/inputImageSize(2),anchorBoxes(:,1)/inputImageSize(1)];
        
        % Get the iou of best matching anchor box.
        overLap = bboxOverlapRatio(gtPred,anchorPrior);
        [~,bestAnchorIdx] = max(overLap,[],2);
        
        % Keep only the ground-truth boxes whose best-matching anchor
        % belongs to the mask of this detection head.
        index = ismember(bestAnchorIdx,anchorBoxMask{numPred});
        truthBatch = truthBatch(index,:);
        bestAnchorIdx = bestAnchorIdx(index,:);
        bestAnchorIdx = bestAnchorIdx - anchorBoxMask{numPred}(1,1) + 1;
        
        if ~isempty(truthBatch)
            % Convert the top-left position of the ground truth to center coordinates.
            truthBatch = [truthBatch(:,1)+truthBatch(:,3)./2,truthBatch(:,2)+truthBatch(:,4)./2,truthBatch(:,3),truthBatch(:,4),truthBatch(:,5)];
            
            errorScale = 2 - truthBatch(:,3).*truthBatch(:,4);
            truthBatch = [truthBatch(:,1)*gridSize(2),truthBatch(:,2)*gridSize(1),truthBatch(:,3)*inputImageSize(2),truthBatch(:,4)*inputImageSize(1),truthBatch(:,5)];
            for t = 1:size(truthBatch,1)
                
                % Get the position of ground-truth box in the grid.
                colIdx = ceil(truthBatch(t,1));
                colIdx(colIdx<1) = 1;
                colIdx(colIdx>gridSize(2)) = gridSize(2);
                rowIdx = ceil(truthBatch(t,2));
                rowIdx(rowIdx<1) = 1;
                rowIdx(rowIdx>gridSize(1)) = gridSize(1);
                pos = [rowIdx,colIdx];
                anchorIdx = bestAnchorIdx(t,1);
                
                mask(pos(1,1),pos(1,2),anchorIdx,batch) = 1;
                confMask(pos(1,1),pos(1,2),anchorIdx,batch) = 1;
                
                % Calculate the shift in ground-truth boxes.
                tShiftX = truthBatch(t,1)-pos(1,2)+1;
                tShiftY = truthBatch(t,2)-pos(1,1)+1;
                tShiftW = log(truthBatch(t,3)/anchors(anchorIdx,2));
                tShiftH = log(truthBatch(t,4)/anchors(anchorIdx,1));
                
                % Update the target box.
                tx(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftX;
                ty(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftY;
                tw(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftW;
                th(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftH;
                boxErrorScale(pos(1,1),pos(1,2),anchorIdx,batch) = errorScale(t);
                tconf(rowIdx,colIdx,anchorIdx,batch) = 1;
                classIdx = (numClasses*(anchorIdx-1))+truthBatch(t,5);
                tclass(rowIdx,colIdx,classIdx,batch) = 1;
                classMask(rowIdx,colIdx,(numClasses*(anchorIdx-1))+(1:numClasses),batch) = 1;
            end
        end
    end
    boxDeltaTarget(numPred,:) = [{tx} {ty} {tw} {th}];
    objectnessTarget{numPred,1} = tconf;
    classTarget{numPred,1} = tclass;
    maskTarget(numPred,:) = [{mask} {confMask} {classMask}];
    boxErrorScaleTarget{numPred,:} = boxErrorScale;
end
end
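The targets built in generateTargets invert the decoding in applyAnchorBoxOffsets: encoding a ground-truth box and then decoding the resulting shifts should return the original box. The sketch below checks this round trip for one hypothetical box, grid, and anchor using the same formulas. It starts from a box that is already in normalized center form, whereas generateTargets first converts top-left coordinates to centers. It also computes the errorScale weight, which is larger for smaller boxes.

```matlab
inputImageSize = [227 227 3];   % [height width channels]
gridSize = [14 14];             % hypothetical [rows cols] of one output grid
anchor = [40 60];               % hypothetical anchor, [height width] in pixels

% Hypothetical normalized ground truth: [centerX centerY width height].
gt = [0.40 0.30 0.25 0.20];

% Encoding, following generateTargets.
errorScale = 2 - gt(3)*gt(4);   % small boxes get weights closer to 2
xGrid = gt(1)*gridSize(2);      % center in grid units
yGrid = gt(2)*gridSize(1);
colIdx = min(max(ceil(xGrid), 1), gridSize(2));
rowIdx = min(max(ceil(yGrid), 1), gridSize(1));
tShiftX = xGrid - colIdx + 1;   % offset inside the cell, in (0, 1]
tShiftY = yGrid - rowIdx + 1;
tShiftW = log(gt(3)*inputImageSize(2)/anchor(2));
tShiftH = log(gt(4)*inputImageSize(1)/anchor(1));

% Decoding, following applyAnchorBoxOffsets with zero-based cell offsets.
% The x/y shifts are already in-cell offsets, so they are added directly;
% the width/height shifts are log ratios, so they are exponentiated.
bx = ((colIdx - 1) + tShiftX)/gridSize(2);
by = ((rowIdx - 1) + tShiftY)/gridSize(1);
bw = anchor(2)*exp(tShiftW)/inputImageSize(2);
bh = anchor(1)*exp(tShiftH)/inputImageSize(1);

decoded = [bx by bw bh];        % matches gt up to rounding
```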

function iou = getMaxIOUPredictedWithGroundTruth(predx,predy,predw,predh,truth)
% getMaxIOUPredictedWithGroundTruth computes, for every prediction, the
% maximum intersection over union (IoU) score across all ground-truth boxes
% in the same image.

[h,w,c,n] = size(predx);
iou = zeros([h w c n],'like',predx);

% For each batch prepare the predictions and ground-truth.
for batchSize = 1:n
    truthBatch = truth(:,1:4,1,batchSize);
    truthBatch = truthBatch(all(truthBatch,2),:);
    predxb = predx(:,:,:,batchSize);
    predyb = predy(:,:,:,batchSize);
    predwb = predw(:,:,:,batchSize);
    predhb = predh(:,:,:,batchSize);
    predb = [predxb(:),predyb(:),predwb(:),predhb(:)];
    
    % Convert from center xy coordinate to topleft xy coordinate.
    predb = [predb(:,1)-predb(:,3)./2, predb(:,2)-predb(:,4)./2, predb(:,3), predb(:,4)];
    
    % Compute and extract the maximum IOU of predictions with ground-truth.
    try 
        overlap = bboxOverlapRatio(predb, truthBatch);
    catch me
        if(any(isnan(predb(:))|isinf(predb(:))))
            error(me.message + " NaN/Inf has been detected during training. Try reducing the learning rate.");
        elseif(any(predb(:,3)<=0 | predb(:,4)<=0))
            error(me.message + " Invalid predictions during training. Try reducing the learning rate.");
        else
            error(me.message + " Invalid groundtruth. Check that your ground truth boxes are not empty and finite, are fully contained within the image boundary, and have positive width and height.");
        end
    end
    
    maxOverlap = max(overlap,[],2);
    iou(:,:,:,batchSize) = reshape(maxOverlap,h,w,c);
end
end
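The quantity that getMaxIOUPredictedWithGroundTruth extracts from bboxOverlapRatio is the intersection over union of axis-aligned boxes, maximized over the ground-truth boxes. The following plain-MATLAB sketch computes it for one hypothetical prediction and two hypothetical ground-truth boxes. It is a simplified stand-in for bboxOverlapRatio with its default union ratio: it treats box coordinates as continuous and may differ from that function in edge-case conventions.

```matlab
% Hypothetical prediction in center form: [centerX centerY width height].
predCenter = [1 1 2 2];

% Convert from center xy to top-left xy, as in the function above.
pred = [predCenter(1)-predCenter(3)/2, predCenter(2)-predCenter(4)/2, predCenter(3), predCenter(4)];

% Hypothetical ground-truth boxes for the same image: [x y width height].
truth = [1 1 2 2;
         5 5 1 1];

iouPerTruth = zeros(size(truth, 1), 1);
for k = 1:size(truth, 1)
    gt = truth(k, :);
    % Overlap along each axis, clamped at zero for disjoint boxes.
    interW = max(0, min(pred(1)+pred(3), gt(1)+gt(3)) - max(pred(1), gt(1)));
    interH = max(0, min(pred(2)+pred(4), gt(2)+gt(4)) - max(pred(2), gt(2)));
    interArea = interW*interH;
    unionArea = pred(3)*pred(4) + gt(3)*gt(4) - interArea;
    iouPerTruth(k) = interArea/unionArea;
end

% Keep only the best match for this prediction.
maxIoU = max(iouPerTruth);
```

In generateTargets, predictions whose maximum IoU exceeds penaltyThreshold are excluded from the objectness penalty through confMask.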


function tiledAnchors = generateTiledAnchors(YPredCell,anchorBoxes,anchorBoxMask)
% Generate tiled anchor offset.
tiledAnchors = cell(size(YPredCell));
for i=1:size(YPredCell,1)
    anchors = anchorBoxes(anchorBoxMask{i}, :);
    [h,w,~,n] = size(YPredCell{i,1});
    [tiledAnchors{i,2}, tiledAnchors{i,1}] = ndgrid(0:h-1,0:w-1,1:size(anchors,1),1:n);
    [~,~,tiledAnchors{i,3}] = ndgrid(0:h-1,0:w-1,anchors(:,2),1:n);
    [~,~,tiledAnchors{i,4}] = ndgrid(0:h-1,0:w-1,anchors(:,1),1:n);
end
end
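generateTiledAnchors uses ndgrid to give every grid cell and anchor its zero-based column offset, row offset, anchor width, and anchor height, laid out with the same h-by-w-by-anchors-by-batch shape as the predictions. The sketch below reproduces those ndgrid calls for a hypothetical 2-by-3 grid with two anchors and a batch of one, so that the layout can be inspected directly.

```matlab
h = 2; w = 3;                 % hypothetical grid: 2 rows, 3 columns
anchors = [40 60; 20 30];     % hypothetical anchors, [height width] per row
n = 1;                        % batch size

% Row offsets vary along dimension 1, column offsets along dimension 2.
[rowOffsets, colOffsets] = ndgrid(0:h-1, 0:w-1, 1:size(anchors,1), 1:n);

% Anchor sizes vary along dimension 3, one slice per anchor.
[~, ~, anchorWidths]  = ndgrid(0:h-1, 0:w-1, anchors(:,2), 1:n);
[~, ~, anchorHeights] = ndgrid(0:h-1, 0:w-1, anchors(:,1), 1:n);

% colOffsets(:,:,1) is [0 1 2; 0 1 2] and rowOffsets(:,:,1) is [0 0 0; 1 1 1].
```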

function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
% Binary cross-entropy loss for objectness score.
objLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'ClassificationMode','multilabel'),objectnessPredCell,objectnessDeltaTarget,boxMaskTarget(:,2)));
end

function [bboxes,scores,labels] = yolov3Detect(net, XTest, networkOutputs, anchors, anchorBoxMask, confidenceThreshold, overlapThreshold, classes)
% The yolov3Detect function detects the bounding boxes, scores, and labels
% in an image.
imageSize = size(XTest, [1,2]);

% To retain 'networkInputSize' in memory and avoid recalculating it,
% declare it as persistent.
persistent networkInputSize

if isempty(networkInputSize)
    networkInputSize = [227 227 3];
end

% Predict and filter the detections based on confidence threshold.
predictions = yolov3Predict(net,XTest,networkOutputs,anchorBoxMask);
predictions = cellfun(@ gather, predictions,'UniformOutput',false);
predictions = cellfun(@ extractdata, predictions, 'UniformOutput', false);
tiledAnchors = generateTiledAnchors(predictions(:,2:5),anchors,anchorBoxMask);
predictions(:,2:5) = applyAnchorBoxOffsets(tiledAnchors, predictions(:,2:5), networkInputSize);

numMiniBatch = size(XTest, 4);

bboxes = cell(numMiniBatch, 1);
scores = cell(numMiniBatch, 1);
labels = cell(numMiniBatch, 1);

for ii = 1:numMiniBatch
    fmap = cellfun(@(x) x(:,:,:,ii), predictions, 'UniformOutput', false);
    [bboxes{ii}, scores{ii}, labels{ii}] = ...
        generateYOLOv3Detections(fmap, confidenceThreshold, overlapThreshold, imageSize, classes);
end

end

function YPredCell = yolov3Predict(net,XTrain,networkOutputs,anchorBoxMask)
% Predict the output of network and extract the confidence, x, y, width,
% height, and class.
YPredictions = cell(size(networkOutputs));
[YPredictions{:}] = predict(net, XTrain);
YPredCell = extractPredictions(YPredictions, anchorBoxMask);

% Apply activation to the predicted cell array.
YPredCell = applyActivations(YPredCell);
end

function [YPredCell, state] = yolov3Forward(net, X,  anchorBoxMask)
% Predict the output of network.
numNetOutputs = numel(net.OutputNames);
networkOuts = cell(numNetOutputs, 1);

% Retrieve the network outputs and the updated network state.
[networkOuts{:}, state] = forward(net, X);

YPredCell = extractPredictions(networkOuts, anchorBoxMask);

% Append predicted width and height to the end as they are required for
% computing the loss.
YPredCell(:,7:8) = YPredCell(:,4:5);

% Apply sigmoid and exponential activation.
YPredCell(:,1:6) = applyActivations(YPredCell(:,1:6));
end

function [YPredCell, state, activations] = yolov3ForwardGate(prunableNet, X,  anchorBoxMask)
% Predict the output of network.
numNetOutputs = numel(prunableNet.OutputNames);
networkOuts = cell(numNetOutputs, 1);

% Retrieve the network outputs, the updated network state, and the pruning
% activations.
[networkOuts{:}, state, activations] = forward(prunableNet, X);

YPredCell = extractPredictions(networkOuts, anchorBoxMask);

% Append predicted width and height to the end as they are required for
% computing the loss.
YPredCell(:,7:8) = YPredCell(:,4:5);

% Apply sigmoid and exponential activation.
YPredCell(:,1:6) = applyActivations(YPredCell(:,1:6));
end

References

[1] Molchanov, Pavlo, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. "Pruning Convolutional Neural Networks for Resource Efficient Inference." Preprint, submitted June 8, 2017. https://arxiv.org/abs/1611.06440.

[2] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. "Importance Estimation for Neural Network Pruning." In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01152.

[3] Redmon, Joseph, and Ali Farhadi. "YOLOv3: An Incremental Improvement." Preprint, submitted April 8, 2018. https://arxiv.org/abs/1804.02767.

See Also

Functions