
This example shows how to use the local interpretable model-agnostic explanations (LIME) technique to understand the predictions of a deep neural network that classifies tabular data. You can use LIME to determine which predictors are most important to the network's classification decisions.

In this example, you interpret a feature data classification network using LIME. For a specified query observation, LIME generates a synthetic data set whose statistics for each feature match those of the real data set. LIME passes the synthetic data through the deep neural network to obtain a classification for each synthetic observation and then fits a simple, interpretable model to those classifications. You can use this simple model to understand the importance of the top few features to the network's classification decision. When training the interpretable model, LIME weights each synthetic observation by its distance from the query observation, so the explanation is "local" to that observation.
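The synthetic-data step can be sketched in base MATLAB. This is a minimal illustration under a stated assumption, not the internals of `lime`: each feature is sampled independently from a normal distribution with the same mean and standard deviation as the corresponding real feature. The matrix `X` is a small hypothetical stand-in for the real predictor data.

```matlab
% Minimal sketch: generate synthetic samples whose per-feature mean and
% standard deviation match a real data matrix.
% Assumption: independent normal sampling per feature (illustrative only,
% not necessarily what lime does internally).
rng('default')                        % reproducible sampling
X = [5.1 3.5 1.4 0.2;                 % hypothetical stand-in for real data
     7.0 3.2 4.7 1.4;
     6.3 3.3 6.0 2.5;
     4.9 3.0 1.4 0.2;
     6.4 3.2 4.5 1.5];
mu = mean(X,1);                       % per-feature mean (1-by-4)
sigma = std(X,0,1);                   % per-feature standard deviation (1-by-4)
numSynthetic = 5000;
% Scale and shift standard normal draws column by column (implicit expansion)
Xsynthetic = mu + sigma .* randn(numSynthetic,size(X,2));
disp([mu; mean(Xsynthetic,1)])        % real vs synthetic means
disp([sigma; std(Xsynthetic,0,1)])    % real vs synthetic standard deviations
```

With 5000 synthetic samples, the synthetic means and standard deviations should lie close to the real ones. Sampling features independently ignores correlations between features, so some synthetic points can fall in regions the real data never covers.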

This example uses `lime` (Statistics and Machine Learning Toolbox) and `fit` (Statistics and Machine Learning Toolbox) to generate a synthetic data set and fit a simple interpretable model to the synthetic data set. To understand the predictions of a trained image classification neural network, use `imageLIME`. For more information, see Understand Network Predictions Using LIME.

Load the Fisher iris data set. This data set contains 150 observations, each with four numeric input features describing the plant and one categorical response giving the plant species. Each observation belongs to one of three species: setosa, versicolor, or virginica. The four measurements are sepal length, sepal width, petal length, and petal width.

```
filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)
```

Convert the numeric data to a table.

```
features = ["Sepal length","Sepal width","Petal length","Petal width"];
predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");
```

Create a table of training data whose final column is the response.

```
data = [predictors trueLabels];
```

Calculate the number of observations, features, and classes.

```
numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));
```

Partition the data set into training, validation, and test sets. Use 70% of the data for training and set aside 15% for validation; the remaining observations (about 15%) form the test set.

Determine the number of observations for each partition. Set the random seed to make the data splitting and CPU training reproducible.

```
rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);
```
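As a worked check of the arithmetic, the snippet below (base MATLAB only) computes the three partition sizes for the 150 observations. It assumes the test set takes every observation left over after the training and validation partitions, as the indexing in the next step does.

```matlab
% Worked check of the partition sizes for the 150-observation iris data.
numObservations = 150;
numObservationsTrain = floor(0.7*numObservations);        % 105
numObservationsValidation = floor(0.15*numObservations);  % floor(22.5) = 22
% The test set receives every remaining observation.
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation;  % 23
fprintf("train = %d, validation = %d, test = %d\n", ...
    numObservationsTrain,numObservationsValidation,numObservationsTest)
```

These sizes agree with the training log: every validation accuracy it reports is a multiple of 1/22 (for example, 31.82% = 7/22 and 95.45% = 21/22).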

Create an array of random indices corresponding to the observations and partition it using the partition sizes.

```
idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);
```
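A quick sanity check, written for illustration in base MATLAB, confirms that this indexing scheme produces disjoint partitions that together cover every observation exactly once. It rebuilds the indices from scratch so that it runs on its own; the result does not depend on the random seed.

```matlab
% Sanity check (sketch): the three index sets taken from one randperm
% should be disjoint and together cover every observation exactly once.
numObservations = 150;
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);
idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);
% No index may appear in more than one partition.
noOverlap = isempty(intersect(idxTrain,idxValidation)) && ...
    isempty(intersect(idxTrain,idxTest)) && ...
    isempty(intersect(idxValidation,idxTest));
% Together, the partitions must contain exactly the indices 1..150.
coversAll = isequal(sort([idxTrain idxValidation idxTest]),1:numObservations);
fprintf("disjoint: %d, covers all: %d\n",noOverlap,coversAll)
```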

Partition the table of data into training, validation, and testing partitions using the indices.

```
dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);
```

Create a simple multilayer perceptron with a single hidden layer of five neurons and ReLU activations. The feature input layer accepts data containing numeric scalars that represent features, such as the Fisher iris data set.

```
numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```
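For a sense of the network's size, the following worked count (plain arithmetic, not a Deep Learning Toolbox call) adds up the weights and biases of the two fully connected layers; the ReLU, softmax, and classification layers have no learnable parameters. The values 4, 5, and 3 are `numFeatures`, `numHiddenUnits`, and `numClasses` for the iris data.

```matlab
% Worked count of learnable parameters in the 4-5-3 perceptron
% (weights plus biases of each fully connected layer).
numFeatures = 4; numHiddenUnits = 5; numClasses = 3;
paramsHidden = numFeatures*numHiddenUnits + numHiddenUnits;  % 4*5 + 5 = 25
paramsOutput = numHiddenUnits*numClasses + numClasses;       % 5*3 + 3 = 18
totalParams = paramsHidden + paramsOutput;                   % 43
fprintf("Total learnable parameters: %d\n",totalParams)
```

If Deep Learning Toolbox is available, `analyzeNetwork(layers)` reports the learnables of each layer, which you can compare against this count.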

Train the network using stochastic gradient descent with momentum (SGDM). Set the maximum number of epochs to 30 and use a mini-batch size of 15, because the training data set contains few observations.

```
opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");
```
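Because the mini-batch size divides the 105 training observations evenly, each epoch has a whole number of iterations. The worked check below assumes the 105-observation training partition computed earlier and reproduces the final iteration count in the training log.

```matlab
% Worked check: iterations per epoch and in total for the chosen options.
numObservationsTrain = 105;   % floor(0.7*150), from the partitioning step
miniBatchSize = 15;
maxEpochs = 30;
% Only complete mini-batches count as iterations.
iterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);  % 7
totalIterations = iterationsPerEpoch*maxEpochs;                  % 210
fprintf("%d iterations per epoch, %d in total\n",iterationsPerEpoch,totalIterations)
```

The same arithmetic explains the intermediate rows of the log: iteration 50 falls in epoch ceil(50/7) = 8, and iteration 150 falls in epoch ceil(150/7) = 22.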

Train the network.

```
net = trainNetwork(dataTrain,layers,opts);
```

```
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|
```

Classify observations from the test set using the trained network.

```
predictedLabels = classify(net,dataTest);
trueLabels = dataTest{:,end};
```

Visualize the results using a confusion matrix.

```
figure
confusionchart(trueLabels,predictedLabels)
```

The network successfully uses the four plant features to predict the species of the test observations.

Use LIME to understand the importance of each predictor to the classification decisions of the network.

Investigate the two most important predictors for each observation.

```
numImportantPredictors = 2;
```

Use `lime` to create a synthetic data set whose statistics for each feature match the real data set. Create a `lime` object from a deep learning model `blackbox` (a function handle that classifies its input with the trained network) and the predictor data contained in `predictors`. Use a low `'KernelWidth'` value so that `lime` uses weights focused on the samples near the query point.

```
blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);
```
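To see why a low `'KernelWidth'` value focuses the explanation on nearby samples, the sketch below converts a few hypothetical distances into weights with a squared exponential kernel, w = exp(-(d/width)^2/2). Treat the exact kernel form and the distances as assumptions for illustration; `lime` computes its own distances between the query point and the synthetic samples.

```matlab
% Sketch: squared exponential kernel weights for a few hypothetical
% distances from the query point (assumed kernel form), comparing a
% narrow and a wide kernel.
d = [0 0.05 0.1 0.2 0.5 1];                  % hypothetical distances to the query point
kernelWeights = @(d,width) exp(-0.5*(d./width).^2);
wNarrow = kernelWeights(d,0.1);              % width used in this example
wWide = kernelWeights(d,1);                  % wider kernel for comparison
disp([d; wNarrow; wWide]')                   % columns: distance, narrow, wide
```

With a width of 0.1, a sample at distance 0.5 receives a weight of exp(-12.5), roughly 4e-6, so it barely influences the fit; with a width of 1, the same sample keeps a weight of about 0.88.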

You can use the LIME explainer to identify the features that are most important to the deep neural network's predictions. The `fit` function estimates the importance of a feature by fitting a simple linear model that approximates the neural network in the vicinity of a query observation.
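The core of the surrogate step is a weighted least-squares fit. The following base-MATLAB sketch uses a hypothetical smooth black box (`blackboxScore`, whose output depends only on the third feature), samples points around a hypothetical query, weights them with an assumed squared exponential kernel, and fits a linear model. It only illustrates the idea; `lime` fits its simple model with its own routines and predictor selection, so its coefficients will differ.

```matlab
% Sketch of a LIME-style local surrogate (illustrative only):
% sample around a query point, weight samples by proximity, and fit a
% weighted linear model to the black-box output.
rng('default')
% Hypothetical smooth black box: score driven only by feature 3.
blackboxScore = @(X) 1./(1 + exp(-4*(X(:,3) - 2.5)));
query = [5.0 3.0 2.5 0.8];                 % hypothetical query observation
numSamples = 2000;
sampleStd = 0.5;                           % assumed sampling spread per feature
Xs = query + sampleStd*randn(numSamples,numel(query));
y = blackboxScore(Xs);                     % black-box output for each sample
% Assumed squared exponential kernel on Euclidean distance to the query
kernelWidth = 1;
dist = sqrt(sum((Xs - query).^2,2));
w = exp(-0.5*(dist/kernelWidth).^2);
% Weighted least squares: scale rows by sqrt(w), then solve with backslash.
A = [ones(numSamples,1) Xs - query];       % intercept + features centered at query
sw = sqrt(w);
beta = (sw.*A) \ (sw.*y);
slopes = beta(2:end)'                      % local importance of each feature
```

Because the hypothetical black box depends only on the third feature, the third slope should be clearly positive and the other slopes close to zero.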

Find the indices of the first two observations in the test data corresponding to the setosa class.

```
trueLabelsTest = dataTest{:,end};
label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);
```
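`find` with a second argument `k` returns the indices of the first `k` elements that satisfy the condition, which is how the code above picks two setosa observations. A small illustration with a hypothetical label vector:

```matlab
% Hypothetical labels; find(cond,k) returns the first k matching indices.
labels = categorical(["versicolor";"setosa";"virginica";"setosa";"setosa"]);
idxFirstTwo = find(labels == "setosa",2)   % returns [2; 4]
```

If fewer than `k` matches exist, `find` returns only the matches it finds, so checking `numel(idxSetosa)` before indexing guards against a small test set.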

Use the `fit` function to fit a simple linear model for each of the first two observations from the specified class.

```
explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);
```

Plot the results.

```
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
```

For the setosa class, the most important predictors are a low petal length value and a high sepal width value.

Perform the same analysis for the versicolor class.

```
label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
```

For the versicolor class, a high petal length value is important.

Finally, consider the virginica class.

```
label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
```

For the virginica class, a high petal length value and a low sepal width value are important.

The LIME plots suggest that a high petal length value is associated with the versicolor and virginica classes and a low petal length value is associated with the setosa class. You can investigate the results further by exploring the data.

Plot the petal length of each observation in the data set.

```
setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");
figure
hold on
% Each class is plotted against its own index (1 to 50 within the class)
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off
xlabel("Observation number within class")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])
```

The setosa class has much lower petal length values than the other classes, matching the results produced by the `lime` model.

See also: `fit` (Statistics and Machine Learning Toolbox) | `lime` (Statistics and Machine Learning Toolbox) | `trainNetwork` | `classify` | `featureInputLayer` | `imageLIME`