Shapley values for SVM classification with 10-fold cross-validation
Hi,
I would like to compute Shapley values for an SVM evaluated with 10-fold cross-validation (Kfold=10). Here is my code:
load fisheriris
inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
y = species(inds);
SVMModel = fitcsvm(X,y,Kfold=10);
explainer = shapley(SVMModel);
How can I solve this?
Answers (1)
the cyclist
on 25 Feb 2023
Edited: the cyclist
on 25 Feb 2023
Full disclosure: I am a bit new to this myself.
I believe your code fails because, with Kfold=10, fitcsvm trains 10 different SVM models (one per fold) and returns them together as a cross-validated model object. shapley expects a single black-box model.
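A quick way to check this is to inspect what fitcsvm returns when cross-validation is requested. The following is a minimal sketch using the data from the question; the variable name CVModel is just an illustrative choice.

```matlab
% Data from the question
load fisheriris
inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
y = species(inds);

% With KFold, fitcsvm returns a cross-validated (partitioned) model object
CVModel = fitcsvm(X,y,KFold=10);

% Display the class of the returned object and of one per-fold model
disp(class(CVModel))
disp(class(CVModel.Trained{1}))

% The per-fold models are stored in the Trained cell array
disp(numel(CVModel.Trained))
```

The object is a container for the fold models plus the partition, which is why it cannot be passed to shapley directly.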
I did not see a way to take the output of fitcsvm (with cross-validation) and get a single model that shapley() will accept. I think that may be possible, but an alternative is to create the individual folds with the cvpartition function and then train each of the SVM models "manually". You can then run shapley on each of the individual models and average the Shapley values across the folds.
Here is some code that gives the idea (created by ChatGPT!). You'll need to swap in your actual data, of course.
I realize this is a pain, and there may be another way. (But I also realize that questions like yours often don't get too many answers here, so I wanted to share what I know.)
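% Example data from the question; swap in your own predictors and labels
load fisheriris
inds = ~strcmp(species,'setosa');
data = meas(inds,3:4);
labels = species(inds);
% Split the observations into k folds
k = 10;
cv = cvpartition(size(data,1),'KFold',k);
% Train an SVM model on each fold and compute Shapley values
foldShap = zeros(k,size(data,2));
for i = 1:k
    % Get the training and test data for this fold
    trainData = data(training(cv,i),:);
    testData = data(test(cv,i),:);
    trainLabels = labels(training(cv,i));
    % Train an SVM model on the training data
    model = fitcsvm(trainData,trainLabels);
    % Create the explainer once, then explain one held-out point per call to fit
    explainer = shapley(model,trainData);
    vals = zeros(size(testData));
    for j = 1:size(testData,1)
        explainer = fit(explainer,testData(j,:));
        % ShapleyValues is a table; the last column holds the values for the last class
        vals(j,:) = explainer.ShapleyValues{:,end}';
    end
    % Summarize this fold as the mean absolute Shapley value per predictor (one possible choice)
    foldShap(i,:) = mean(abs(vals),1);
end
% Average the per-fold summaries across folds
meanShapValues = mean(foldShap,1);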
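It may also be possible to reuse the cross-validated object directly rather than retraining each fold by hand. The per-fold models are stored in its Trained property as compact models, and shapley should accept a compact model as long as the predictor data is passed as the second argument, since compact models do not store it. The following is a sketch under that assumption, using the data from the question. The names CVModel, cvp, foldShap and meanAbsShap are illustrative, and summarizing each fold by the mean absolute Shapley value is one choice among several.

```matlab
% Data from the question
load fisheriris
inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
y = species(inds);

% Cross-validated SVM, as in the question
CVModel = fitcsvm(X,y,KFold=10);
cvp = CVModel.Partition;            % cvpartition object used for the folds
k = cvp.NumTestSets;

foldShap = zeros(k,size(X,2));
for i = 1:k
    foldModel = CVModel.Trained{i}; % compact SVM trained on fold i's training set
    Xtrain = X(training(cvp,i),:);
    Xtest = X(test(cvp,i),:);
    % Compact models carry no data, so pass the training predictors explicitly
    explainer = shapley(foldModel,Xtrain);
    vals = zeros(size(Xtest));
    for j = 1:size(Xtest,1)
        % Explain one held-out observation per call to fit
        explainer = fit(explainer,Xtest(j,:));
        % Keep the Shapley values reported for the last class column
        vals(j,:) = explainer.ShapleyValues{:,end}';
    end
    foldShap(i,:) = mean(abs(vals),1);
end

% One importance value per predictor, averaged over the folds
meanAbsShap = mean(foldShap,1);
disp(meanAbsShap)
```

This avoids training the models twice and guarantees that the explained models are exactly the ones used for the cross-validation results.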
3 Comments
the cyclist
on 25 Feb 2023
Edited: the cyclist
on 25 Feb 2023
I'm really happy to have helped. I do feel that what I wrote (aided by ChatGPT) is still pretty awkward. If you found a more elegant or efficient approach, please share it here.
I assume by "cycle", you mean you used for loops, which I guess is needed over both the folds and the data points.