Define Custom Metric Object
Note
This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.
In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
How To Decide Which Metric Type To Use
If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.
To define a custom deep learning metric class, you can use the template in this example, which takes you through these steps:
Name the metric — Give the metric a name so that you can use it in MATLAB®.
Declare the metric properties — Specify the public and private properties of the metric.
Create a constructor function — Specify how to construct the metric and set default values.
Create an initialization function (optional) — Specify how to initialize variables and run validation checks.
Create a reset function — Specify how to reset the metric properties between iterations.
Create an update function — Specify how to update metric properties between iterations.
Create an aggregation function — Specify how to aggregate the metric values across multiple instances of the metric object.
Create an evaluation function — Specify how to calculate the metric value for each iteration.
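This example shows how to create a custom false positive rate (FPR) metric. This equation defines the metric:
$$\mathrm{FPR} = \frac{\mathrm{FP}}{\mathrm{FP} + \mathrm{TN}}$$
where FP is the number of false positives and TN is the number of true negatives.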
To see the completed metric class definition, see Completed Metric.
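As a minimal worked illustration of the FPR definition, the following sketch counts false positives and true negatives for a small binary problem. The label vectors are hypothetical values chosen for illustration and are not part of the example in this topic.

```matlab
% Hypothetical binary predictions and ground-truth labels (1 = positive).
pred  = logical([1 0 1 1 0 0 0 1]);
truth = logical([1 0 0 1 0 0 1 0]);

% False positives: predicted positive, actually negative.
fp = sum(pred & ~truth);

% True negatives: predicted negative, actually negative.
tn = sum(~pred & ~truth);

% False positive rate: the fraction of actual negatives predicted as positive.
fpr = fp/(fp + tn);
```

Here, two of the five actual negatives are predicted as positive, so fpr is 0.4.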
Metric Template
Copy the metric template into a new file in MATLAB. This template gives the structure of a metric class definition. It outlines:
The properties block for public metric properties. This block must contain the Name property.
The properties block for private metric properties. This block is optional.
The metric constructor function.
The optional initialize function.
The required reset, update, aggregate, and evaluate functions.
classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties

            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %                    evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end
Metric Name
First, give the metric a name. In the first line of the class file, replace the existing name myMetric with fprMetric.
classdef fprMetric < deep.Metric
    ...
end
Next, rename the myMetric constructor function (the first function in the methods section) so that it has the same name as the metric.
methods
    function metric = fprMetric(args)
        ...
    end
    ...
end
Save Metric
Save the metric class file in a new file with the name fprMetric and the .m extension. The file name must match the metric name. To use the metric, you must save the file in the current folder or in a folder on the MATLAB path.
Declare Properties
Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.
properties: Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.
properties (Access = private): Only members of the defining class can access these properties.
Declare Public Properties
Declare public properties by listing them in the properties section. This section must contain the Name property.
properties
    % (Required) Metric name.
    Name
end
Declare Private Properties
Declare private properties by listing them in the properties (Access = private) section. This metric requires two properties to evaluate the value: true negatives (TNs) and false positives (FPs). Only the functions within the metric class require access to these values.
properties (Access = private)
    % Define true negatives (TNs) and false positives (FPs).
    TrueNegatives
    FalsePositives
end
Create Constructor Function
Create the function that constructs the metric and initializes the metric properties. If the software requires any variables to evaluate the metric value, then these variables must be inputs to the constructor function.
The FPR metric constructor function accepts the Name, NetworkOutput, and Maximize arguments. These arguments are optional when you use the constructor to create a metric object. Specify an args input to the fprMetric function that corresponds to the optional name-value arguments. Add a comment to explain the syntax of the function.
function metric = fprMetric(args)
    % metric = fprMetric creates an fprMetric metric object.

    % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
    % also specifies the optional Name, NetworkOutput, and Maximize
    % options. By default, the metric name is "FPR". By default,
    % the NetworkOutput is [], which corresponds to using all of
    % the network outputs. Maximize is set to 0 as the optimal value
    % occurs when the FPR is minimized.
    ...
end
Next, set the default values for the metric properties. Parse the input arguments using an arguments block. Specify the default metric name as "FPR", the default network output as [], and the Maximize property as 0. The metric name appears in plots and verbose output.
function metric = fprMetric(args)
    ...
    arguments
        args.Name = "FPR"
        args.NetworkOutput = []
        args.Maximize = 0
    end
    ...
end
Set the properties of the metric.
function metric = fprMetric(args)
    ...
    % Set the metric name.
    metric.Name = args.Name;

    % To support this metric for use with multi-output networks, set
    % the network output.
    metric.NetworkOutput = args.NetworkOutput;

    % To support this metric for early stopping and returning the
    % best network, set the maximize property.
    metric.Maximize = args.Maximize;
end
View the completed constructor function. With this constructor function, the command fprMetric(Name="fpr") creates an FPR metric object with the name "fpr".
function metric = fprMetric(args)
    % metric = fprMetric creates an fprMetric metric object.

    % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
    % also specifies the optional Name, NetworkOutput, and Maximize
    % options. By default, the metric name is "FPR". By default,
    % the NetworkOutput is [], which corresponds to using all of
    % the network outputs. Maximize is set to 0 as the optimal value
    % occurs when the FPR is minimized.

    arguments
        args.Name = "FPR"
        args.NetworkOutput = []
        args.Maximize = 0
    end

    % Set the metric name.
    metric.Name = args.Name;

    % To support this metric for use with multi-output networks, set
    % the network output.
    metric.NetworkOutput = args.NetworkOutput;

    % To support this metric for early stopping and returning the
    % best network, set the maximize property.
    metric.Maximize = args.Maximize;
end
Create Initialization Function
Create the optional function that initializes variables and runs validation checks.
For this example, the metric does not need the initialize function, so you can delete it. For an example of an initialize function, see Initialization Function.
Create Reset Function
Create the function that resets the metric properties. The software calls this function before each iteration. For the FPR score metric, reset the TN and FP values to zero at the start of each iteration.
function metric = reset(metric)
    % metric = reset(metric) resets the metric properties.
    metric.TrueNegatives = 0;
    metric.FalsePositives = 0;
end
Create Update Function
Create the function that updates the metric properties that you use to compute the FPR score value. The software calls this function in each training and validation mini-batch.
In the update function, define these steps:
Find the maximum score for each observation. The maximum score corresponds to the predicted class for each observation.
Find the TN and FP values.
Add the batch TN and FP values to the running total number of TNs and FPs.
function metric = update(metric,batchY,batchT)
    % metric = update(metric,batchY,batchT) updates the metric
    % properties.

    % Find the channel (class) dimension.
    cDim = finddim(batchY,"C");

    % Find the maximum score, which corresponds to the predicted
    % class. Set the predicted class to 1 and all other classes to 0.
    batchY = batchY == max(batchY,[],cDim);

    % Find the TN and FP values for this batch.
    batchTrueNegatives = sum(~batchY & ~batchT, 2);
    batchFalsePositives = sum(batchY & ~batchT, 2);

    % Add the batch values to the running totals and update the metric
    % properties.
    metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
    metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
end
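To see what the counting steps in the update function produce, the following sketch applies the same operations to a small numeric array outside of the class. The scores and targets are hypothetical, and the sketch assumes plain numeric arrays with classes along the first dimension and observations along the second, so it sets cDim to 1 instead of calling finddim.

```matlab
% Hypothetical scores for 3 classes (rows) and 4 observations (columns).
scores = [0.7 0.2 0.1 0.3
          0.2 0.5 0.3 0.6
          0.1 0.3 0.6 0.1];

% Hypothetical one-hot encoded targets with the same layout.
T = [1 0 0 0
     0 1 1 0
     0 0 0 1];

% Assumption: the class dimension is the first dimension.
cDim = 1;

% Set the predicted class to 1 and all other classes to 0.
Y = scores == max(scores,[],cDim);

% Count the TN and FP values for each class across the observations.
batchTrueNegatives = sum(~Y & ~T, 2);
batchFalsePositives = sum(Y & ~T, 2);
```

For this data, batchFalsePositives is [0; 1; 1] and batchTrueNegatives is [3; 1; 2], one value per class.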
For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.
When using the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.
When using the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
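The update function in this example applies logical operations to batchT element by element, which is meaningful when the targets are one-hot encoded. If a metric receives numeric class indices instead, one possible approach is to expand the indices to a one-hot array first. This is a minimal sketch with hypothetical values; it assumes one index per observation and a classes-by-observations layout, which might not match the layout that the software passes to your metric.

```matlab
% Hypothetical number of classes and class indices, one per observation.
numClasses = 3;
idx = [1 2 2 3];

% Compare each class number (column vector) with each index (row vector).
% Implicit expansion produces a numClasses-by-numObservations logical array.
T = (1:numClasses)' == idx;
```

Each column of T contains a single true value in the row that corresponds to the class index of that observation.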
Create Aggregation Function
Create the function that specifies how to combine the metric values and properties across multiple instances of the metric. For example, the aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training.
For this example, to combine the TN and FP values, add the values from each metric instance.
function metric = aggregate(metric,metric2)
    % metric = aggregate(metric,metric2) aggregates the metric
    % properties across two instances of the metric.

    metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
    metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
end
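The following sketch mimics the aggregation step with two plain structures in place of metric objects. The structures a and b and their counts are hypothetical stand-ins for the private properties of two metric instances.

```matlab
% Hypothetical per-class counts held by two instances of the metric.
a.TrueNegatives  = [3; 1; 2];
a.FalsePositives = [0; 1; 1];
b.TrueNegatives  = [2; 3; 3];
b.FalsePositives = [1; 0; 0];

% Aggregate by adding the counts, as in the aggregate function.
combined.TrueNegatives  = a.TrueNegatives + b.TrueNegatives;
combined.FalsePositives = a.FalsePositives + b.FalsePositives;

% Per-class rate computed from the combined counts.
combinedRate = combined.FalsePositives./ ...
    (combined.FalsePositives + combined.TrueNegatives);
```

Adding the raw counts and computing the rate afterward weights every observation equally. Averaging the rates of the two instances would generally give a different result when the instances see different numbers of negatives.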
Create Evaluation Function
Create the function that specifies how to compute the metric value in each iteration. This equation defines the FPR metric:
$$\mathrm{FPR} = \frac{\mathrm{FP}}{\mathrm{FP} + \mathrm{TN}}$$
Implement this equation in the evaluate function. Find the macro average by taking the average across all the classes.
function val = evaluate(metric)
    % val = evaluate(metric) uses the properties in metric to return the
    % evaluated metric value.

    % Extract TN and FP values.
    tn = metric.TrueNegatives;
    fp = metric.FalsePositives;

    % Compute the FPR value for each class and average across classes.
    val = mean(fp./(fp+tn+eps));
end
As the denominator value of this metric can be zero, add eps to the denominator to prevent the metric returning a NaN value.
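To illustrate the effect of eps, this sketch compares the per-class rate with and without it for hypothetical counts in which one class has no negative observations. It uses element-wise division (./) so that each class gets its own rate.

```matlab
% Hypothetical per-class counts. The third class has no negatives.
fp = [0; 1; 0];
tn = [3; 1; 0];

% Without eps, the third class evaluates 0/0, which is NaN.
rawRate = fp./(fp + tn);

% With eps in the denominator, the third class evaluates to 0.
safeRate = fp./(fp + tn + eps);

% Macro average across the classes.
macroFPR = mean(safeRate);
```

Because mean propagates NaN values, a single class with a zero denominator would otherwise make the whole metric value NaN.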
Completed Metric
View the completed metric class file.
Note
For more information about when the software calls each function in the class, see Function Call Order.
classdef fprMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name
    end

    properties (Access = private)
        % Define true negatives (TNs) and false positives (FPs).
        TrueNegatives
        FalsePositives
    end

    methods
        function metric = fprMetric(args)
            % metric = fprMetric creates an fprMetric metric object.

            % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
            % also specifies the optional Name option. By default,
            % the metric name is "FPR". By default,
            % the NetworkOutput is [], which corresponds to using all of
            % the network outputs. Maximize is set to 0 as the optimal value
            % occurs when the FPR value is minimized.

            arguments
                args.Name = "FPR"
                args.NetworkOutput = []
                args.Maximize = false
            end

            % Set the metric name value.
            metric.Name = args.Name;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the maximize property.
            metric.Maximize = args.Maximize;
        end

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TrueNegatives = 0;
            metric.FalsePositives = 0;
        end

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the TN and FP values for this batch.
            batchTrueNegatives = sum(~batchY & ~batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
        end

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.

            metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
        end

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TN and FP values.
            tn = metric.TrueNegatives;
            fp = metric.FalsePositives;

            % Compute the FPR value.
            val = mean(fp./(fp+tn+eps));
        end
    end
end
Use Custom Metric During Training
You can use a custom metric in the same way as any other metric in Deep Learning Toolbox™. This section shows how to create and train a network for digit classification and track the FPR value.
Unzip the digit sample data and create an image datastore. The imageDatastore function automatically labels the images based on folder names.
unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
Use a subset of the data as the validation set, and then define the network layers.
numTrainingFiles = 750;
[imdsTrain,imdsVal] = splitEachLabel(imds,numTrainingFiles,"randomize");

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    fullyConnectedLayer(10)
    softmaxLayer];
Create an fprMetric object.
metric = fprMetric(Name="FalsePositiveRate")
metric = 
  fprMetric with properties:

             Name: "FalsePositiveRate"
    NetworkOutput: []
         Maximize: 0
Specify the FPR metric in the training options. To plot the metric during training, set Plots to "training-progress". To output the values during training, set Verbose to true. Return the network that achieves the best FPR value.
options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    Metrics=metric, ...
    ValidationData=imdsVal, ...
    ValidationFrequency=50, ...
    Verbose=true, ...
    Plots="training-progress", ...
    ObjectiveMetricName="FalsePositiveRate", ...
    OutputNetwork="best-validation");
Train the network using the trainnet function. The values for the training and validation sets appear in the plot.
net = trainnet(imdsTrain,layers,"crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingFalsePositiveRate    ValidationFalsePositiveRate
    _________    _____    ___________    _________    ____________    ______________    _________________________    ___________________________
            0        0       00:00:05        0.001                            13.488                                                     0.10018
            1        1       00:00:06        0.001          13.974                                        0.10322
           50        1       00:00:22        0.001          2.7424            2.7448                     0.037368                       0.038889
          100        2       00:00:31        0.001          1.2965            1.2235                     0.027008                       0.023333
          150        3       00:00:37        0.001         0.64661           0.80412                     0.013953                       0.017867
          200        4       00:00:45        0.001         0.18627           0.53273                     0.006153                       0.012311
          250        5       00:00:53        0.001         0.16763           0.49371                    0.0060146                       0.012267
          290        5       00:01:01        0.001         0.25976           0.39347                    0.0062093                      0.0098222
Training stopped: Max epochs completed
See Also
trainingOptions | trainnet | dlnetwork