
Shapley Output Functions

A Shapley output function is a function that the shapley or fit function calls before it starts iterating and at the end of every iteration. The output function can stop Shapley computations, create plots, save information to your workspace, or perform calculations using query point information.

To use the OutputFcn name-value argument in the call to shapley or fit, write a custom output function with this signature:

stop = outputfcn(x,results,state)

The shapley or fit function passes the variables x, results, and state to your output function. Your output function returns stop, which you set to true to stop the iterations, or false to allow the iterations to continue.

  • x contains the Shapley values for the query point at the current iteration.

  • results is a structure with these fields:

    • Iteration — Current iteration number

    • QueryPointIndex — Index of the query point evaluated at the current iteration

    • TimePerQuery — Time spent computing the Shapley values for the query point at the current iteration

    • Method — Method used to compute the Shapley values for the query point at the current iteration

  • state has these possible values:

    • "init": shapley or fit is about to start iterating.

    • "iter": shapley or fit just finished an iteration.

    • "done": shapley or fit just finished its final iteration.
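The following sketch shows the general shape of an output function that handles each state value. The name printprogress is hypothetical. The function reads only the documented results fields, does not use x, and assumes TimePerQuery is in seconds (the earlystop example later on this page compares it to 60 seconds). It never requests early termination.

```matlab
function stop = printprogress(~,results,state)
% PRINTPROGRESS Hypothetical output function that reports progress.
stop = false;   % never stop the iterations
switch state
    case "init"
        disp("Starting Shapley computations")
    case "iter"
        % results is only read in the "iter" state
        fprintf("Iteration %d: query point %d (%s, %.2f s)\n", ...
            results.Iteration,results.QueryPointIndex, ...
            results.Method,results.TimePerQuery)
    case "done"
        disp("Finished Shapley computations")
end
end
```

To use it, pass OutputFcn=@printprogress in the call to shapley or fit.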

Note

To specify an output function in the call to shapley or fit, you must perform the Shapley computations serially. That is, you must set the UseParallel name-value argument to false.
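A minimal sketch of passing an output function to fit while requesting serial computation, using the fisheriris sample data (meas and species). The anonymous output function always returns false, so it never stops the iterations; it stands in for your own function. Setting UseParallel=false explicitly makes the serial requirement visible in the code.

```matlab
load fisheriris                      % meas (150-by-4) and species (150-by-1)
mdl = fitcensemble(meas,species);    % full model stores its predictor data
explainer = shapley(mdl);            % uses the predictor data stored in mdl

% Placeholder output function: never stops; replace with your own
noStop = @(x,results,state) false;

% Output functions require serial computation
explainer = fit(explainer,meas(1:3,:), ...
    OutputFcn=noStop,UseParallel=false);
```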

Stop Shapley Value Computations Early

Train a classification model. Compute the Shapley values for multiple query points. Use an output function to stop the Shapley computations when they exceed a time limit, and then plot the partial results.

Load the CreditRating_Historical data set. The data set contains customer IDs and their financial ratios, industry labels, and credit ratings.

tbl = readtable("CreditRating_Historical.dat");

Display the first three rows of the table.

head(tbl,3)
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______

    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }

Train a black-box model of credit ratings by using the fitcecoc function. Use the variables from the second through seventh columns in tbl as the predictor variables. A recommended practice is to specify the class names to set the order of the classes.

blackbox = fitcecoc(tbl,"Rating", ...
    PredictorNames=tbl.Properties.VariableNames(2:7), ...
    CategoricalPredictors="Industry", ...
    ClassNames={'AAA','AA','A','BBB','BB','B','CCC'});

Create a shapley object that explains the predictions for multiple query points. For faster computation, subsample 10% of the observations from tbl with stratification and use the samples to compute the Shapley values. Specify the sampled observations as the query points.

Use the output function earlystop (shown at the end of this example) to stop the Shapley value computations early if the cumulative computation time exceeds 60 seconds. If shapley stops early, the output function creates two new variables in the workspace: totalTime and numQueryPoints.

rng("default") % For reproducibility
c = cvpartition(tbl.Rating,"Holdout",0.10);
sampleTbl = tbl(test(c),:);
explainer = shapley(blackbox,sampleTbl, ...
    QueryPoints=sampleTbl,OutputFcn=@earlystop);
Iterations terminated prematurely by user.

Display the total Shapley value computation time.

totalTime
totalTime = 60.3373

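Note that the time only slightly exceeds 60 seconds. The overshoot occurs because the output function checks the cumulative time only after each query point finishes.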

Compare the total number of observations in sampleTbl to the number of query points whose Shapley values were computed by shapley.

numObservations = size(sampleTbl,1)
numObservations = 393
numQueryPoints
numQueryPoints = 86
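As a rough check on the partial run, the values displayed above give the fraction of query points processed and the average time per query point. The projected total assumes that every query point takes about as long as the average, which is an assumption rather than a guarantee. The values are typed in as displayed (rounded) so that the sketch runs on its own.

```matlab
% Values displayed in this example (rounded as shown)
totalTime = 60.3373;       % seconds
numQueryPoints = 86;
numObservations = 393;

fractionComputed = numQueryPoints/numObservations   % share of query points processed
avgTimePerQuery = totalTime/numQueryPoints          % average seconds per query point
projectedTotal = avgTimePerQuery*numObservations    % rough estimate for all query points
```

With these values, about 22% of the query points were processed, at roughly 0.70 seconds per query point, which projects to roughly 276 seconds for all 393 query points.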

Display a swarm chart of the partial results.

swarmchart(explainer)

Figure: Shapley Summary Plot (swarm chart; x-axis: Shapley Value, y-axis: Predictor).

Output Function

The output function earlystop uses the query point computation times in results (results.TimePerQuery) to determine whether to stop Shapley computations early. If the cumulative computation time exceeds 60 seconds, the function returns true, which stops the computations. This code creates the earlystop output function.

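function stop = earlystop(~,results,state)
% Stop the Shapley computations when the cumulative time exceeds 60 seconds.
persistent totalTime              % retains its value between calls
stop = false;
switch state
    case "init"
        totalTime = 0;            % reset before the first query point
    case "iter"
        totalTime = totalTime + results.TimePerQuery;
        if totalTime > 60
            % Export partial-run information to the base workspace
            assignin("base","totalTime",totalTime)
            assignin("base","numQueryPoints",results.Iteration)
            stop = true;          % request early termination
        end
    otherwise
        % No action for other states
end
end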
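earlystop hard-codes the 60-second limit. One way to make the limit configurable is to add a fourth input and wrap the function in an anonymous function when you pass it to OutputFcn, because shapley and fit pass only x, results, and state. The name stopafter and the maxSeconds input are hypothetical.

```matlab
function stop = stopafter(~,results,state,maxSeconds)
% STOPAFTER Hypothetical variant of earlystop with a configurable limit.
% maxSeconds is supplied through an anonymous function wrapper.
persistent totalTime
stop = false;
switch state
    case "init"
        totalTime = 0;                      % reset at the start of each run
    case "iter"
        totalTime = totalTime + results.TimePerQuery;
        stop = totalTime > maxSeconds;      % stop once over the time budget
end
end
```

For example, OutputFcn=@(x,results,state) stopafter(x,results,state,30) stops after about 30 seconds. Unlike earlystop, this sketch does not export any variables to the workspace.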

Find Method Used for Individual Shapley Value Computations

Train an ensemble model that uses tree weak learners with surrogate splits. Compute the Shapley values for multiple query points using predictor data that contains missing values. In this case, the Shapley value computation algorithm might not be the same for all query points. Use an output function to determine the method used to compute the Shapley values for each query point.

Read the fisheriris data set, which contains measurements for 150 irises, into a table. SepalLength, SepalWidth, PetalLength, and PetalWidth are the predictor variables, and Species is the response variable.

fisheriris = readtable("fisheriris.csv");

Partition the data into two sets. Use 50% of the observations for training and 50% of the observations for computing Shapley values.

rng("default")
c = cvpartition(fisheriris.Species,"Holdout",0.5);
trainTbl = fisheriris(training(c),:);
queryTbl = fisheriris(test(c),:);

For this example, add a missing value to the second observation in queryTbl.

queryTbl{2,4} = NaN;
queryTbl(2,:)
ans=1×5 table
    SepalLength    SepalWidth    PetalLength    PetalWidth     Species  
    ___________    __________    ___________    __________    __________

        4.9            3             1.4           NaN        {'setosa'}
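To confirm which query points contain missing values (and therefore which ones might be handled differently by the Shapley computations), you can use ismissing. This sketch works on a small, separate subset named demoTbl (a hypothetical name) so that it does not modify queryTbl.

```matlab
fisheriris = readtable("fisheriris.csv");
demoTbl = fisheriris(1:5,:);     % small illustrative subset
demoTbl{2,4} = NaN;              % same edit as above: PetalWidth of row 2

% ismissing returns a logical array the size of the table;
% any(...,2) flags rows with at least one missing entry
rowsWithMissing = find(any(ismissing(demoTbl),2))
```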

Train a classification ensemble by using the fitcensemble function. Specify to use tree stumps with surrogate splits as the weak learners.

tree = templateTree(Surrogate="on",MaxNumSplits=1);
mdl = fitcensemble(trainTbl,"Species",Learners=tree);

Create a shapley object that explains the predictions for the query points in queryTbl. Use the queryTbl predictor data to compute the Shapley values.

Use the output function methodinfo (shown at the end of this example) to find the Shapley value computation algorithm used for each query point. The function also returns the index of the query point evaluated at each iteration.

explainer = shapley(mdl,queryTbl,QueryPoints=queryTbl, ...
    OutputFcn=@methodinfo)
Warning: Computations might be slow when the tree-based model uses surrogate splits for prediction. In this case, the software uses a mix of 'interventional-kernel' and 'interventional-tree'.
explainer = 
shapley explainer with the following mean absolute Shapley values:

      Predictor        setosa      versicolor    virginica
    _____________    __________    __________    _________

    "SepalLength"      0.056765      0.23593       0.17916
    "SepalWidth"     5.7324e-16    3.861e-16     2.959e-16
    "PetalLength"        4.4249       1.6484        3.1843
    "PetalWidth"         0.1696      0.52159       0.69119




The warning message indicates that shapley might use a mix of the Tree SHAP algorithm with an interventional value function and the Kernel SHAP algorithm with an interventional value function. The Method property of the explainer object reflects this information with the value "interventional-mix".

explainer.Method
ans = 
"interventional-mix"

Create a table containing the method information for each query point.

methodInfoTbl = table(queryPointIndex',methodType', ...
    VariableNames=["QueryPointIndex","Method"])
methodInfoTbl=75×2 table
    QueryPointIndex            Method         
    _______________    _______________________

           1           "interventional-kernel"
           2           "interventional-kernel"
           3           "interventional-kernel"
           4           "interventional-kernel"
           5           "interventional-kernel"
           6           "interventional-kernel"
           7           "interventional-kernel"
           8           "interventional-kernel"
           9           "interventional-kernel"
          10           "interventional-kernel"
          11           "interventional-kernel"
          12           "interventional-kernel"
          13           "interventional-kernel"
          14           "interventional-kernel"
          15           "interventional-kernel"
          16           "interventional-kernel"
      ⋮

unique(methodInfoTbl.Method)
ans = 
"interventional-kernel"

In this example, every query point uses the "interventional-kernel" method.
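When the methods do differ across query points, a count per method is easier to read than the full table. A minimal sketch of a hypothetical helper, summarizemethods, that tallies a string array such as methodType.

```matlab
function counts = summarizemethods(methodType)
% SUMMARIZEMETHODS Hypothetical helper: count query points per method.
[methods,~,idx] = unique(methodType(:));   % unique methods, sorted
numPoints = accumarray(idx,1);             % occurrences of each method
counts = table(methods,numPoints, ...
    VariableNames=["Method","NumQueryPoints"]);
end
```

For example, summarizemethods(methodType) returns a one-row table for this example because every query point uses "interventional-kernel".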

As a convenience, the output function methodinfo additionally returns the Shapley values for each query point. This information is also available in the ShapleyValues property of explainer.

Find the Shapley values for the second query point. Recall from the table methodInfoTbl that the function evaluated the second query point during the second iteration.

rowNames = explainer.ShapleyValues{:,1};
varNames = ...
    explainer.ShapleyValues.Properties.VariableNames(2:end);
queryPointInfo = array2table(shapleyValues(:,:,2), ...
    RowNames=rowNames,VariableNames=varNames)
queryPointInfo=4×3 table
                     setosa       versicolor     virginica 
                   ___________    ___________    __________

    SepalLength       0.037345       -0.15521       0.11787
    SepalWidth     -7.5788e-16    -5.0265e-16    7.9886e-16
    PetalLength         6.6859        -2.0038       -4.6821
    PetalWidth        0.067022        0.20267       -0.2697

For an example that shows how to find the Shapley values for a specific query point without using an output function, see Investigate One Query Point After Fitting Multiple Query Points.

Output Function

The output function methodinfo records the query point index (results.QueryPointIndex), Shapley values (x), and method (results.Method) at each iteration. The function returns the information to the MATLAB® workspace as the variables queryPointIndex, shapleyValues, and methodType, respectively. This code creates the methodinfo output function.

function stop = methodinfo(x,results,state)
persistent queryPointIndex
persistent shapleyValues
persistent methodType
stop = false;
switch state
    case "init"
        queryPointIndex = [];
        % Preallocate one page: 4 predictors by 3 classes in this example
        shapleyValues = zeros(4,3,1);
        methodType = "";
    case "iter"
        % Record the information for the current iteration
        queryPointIndex(results.Iteration) = results.QueryPointIndex;
        shapleyValues(:,:,results.Iteration) = x;
        methodType(results.Iteration) = results.Method;
    case "done"
        % Export the recorded information to the base workspace
        assignin("base","queryPointIndex",queryPointIndex)
        assignin("base","shapleyValues",shapleyValues)
        assignin("base","methodType",methodType)
    otherwise
        % No action for other states
end
end
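methodinfo preallocates shapleyValues as a 4-by-3 array, which matches the four predictors and three classes in this example only. A variant that starts from an empty array lets the first Shapley value matrix x determine the page size, so the same function can be reused with other models. The name collectinfo is hypothetical; it writes the same three workspace variables as methodinfo.

```matlab
function stop = collectinfo(x,results,state)
% COLLECTINFO Hypothetical variant of methodinfo that does not hard-code
% the number of predictors or classes.
persistent queryPointIndex shapleyValues methodType
stop = false;
switch state
    case "init"
        queryPointIndex = [];
        shapleyValues = [];          % sized by the first x assigned below
        methodType = "";
    case "iter"
        k = results.Iteration;
        queryPointIndex(k) = results.QueryPointIndex;
        shapleyValues(:,:,k) = x;    % page k holds the values from iteration k
        methodType(k) = results.Method;
    case "done"
        assignin("base","queryPointIndex",queryPointIndex)
        assignin("base","shapleyValues",shapleyValues)
        assignin("base","methodType",methodType)
end
end
```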

See Also


Related Topics