How does the Classification Learner App generate ROC curves for decision trees, and how do I tune the sensitivity/specificity trade-off of the current classifier?

When training models in the Classification Learner App, I noticed that MATLAB always selects a very odd operating point on the ROC curve, usually far off in one of the corners, as seen here:
How does it calculate this curve for decision trees, and where can you set the operating point? That is, how do you set the trade-off between sensitivity and specificity?
The curve shown is for a decision tree that was limited to 5 splits and asked to distinguish between one of two outcomes; the data had about 50 predictors and 5000 entries.
Best regards, Eduard

Answers (2)

Alex van der Meer on 21 Apr 2018
Edited: Alex van der Meer on 21 Apr 2018

I agree with the answer by Benjamin Hoenes. Some additional information: it differs per classifier whether the scores shown in his answer are actually posterior probabilities between 0 and 1 or some other measure. For support vector machines you have to do an extra step (fitSVMPosterior) to obtain the posterior probabilities, and I know the same is true for a boosted trees ensemble.

I did a project on predicting the outcome of a coma from EEG patient data. I used the perfcurve function to obtain the thresholds that are used to create the ROC curve; I needed 100% specificity, so I had to find the right threshold. Due to lack of time I will just copy-paste the main code file from the project, which should provide a lot of insight into this. I used the Generate Code function in the Classification Learner App and then deleted everything I did not need. This code performs 10 runs of 10-fold cross-validation and plots the 10 ROC curves in one figure.
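The core recipe, as a minimal sketch (my variable names; it assumes mdl is a trained classifier whose second score column is the positive-class posterior and whose positive class is labelled 1):

% Cross-validate, then get the cross-validated posterior probabilities
partitionedModel = crossval(mdl, 'KFold', 10);
[~, scores] = kfoldPredict(partitionedModel);
% perfcurve also returns the threshold T behind every point of the ROC
[fpr, tpr, T] = perfcurve(partitionedModel.Y, scores(:,2), 1);
% For 100% specificity, take the last ROC point where the FPr is still 0
idx = find(fpr == 0, 1, 'last');
thresholdAt100Spec = T(idx);

The full project file: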

clear all;
close all;
%% Load in data and set parameters
% Select true for 12H set, false for 24H set.
dataIs12H = true;
if dataIs12H
    % Load the 12H dataset
    temp = load('12HsetComa.mat');
    dataTable12H = temp.featuresNEW12hrs;
    % Put the data into the variable the generated function expects.
    trainingData = dataTable12H;
else
    % Load the 24H dataset
    temp = load('24HsetComa.mat');
    dataTable24H = temp.featuresNEW24hrs;
    % Put the data into the variable the generated function expects.
    trainingData = dataTable24H;
end
% The different feature lists
featureList_Corr12H = [false false true false true true false false false true true true false false false false false false false false false true true false false false false false false false true true false true false false true true false false true false false true];
featureList_InfoGain12H = [false false false false true true false false true true true true true false false false false false false false false false true false false false true false true true false true false true false false true false false false true false false false];
indexes = [7,6,11,10,42,12,41,9,30,18,23,40,5,27,14];
featureList_InfoGain24H = false(1,44);
featureList_InfoGain24H(indexes) = true;
featureList_All = true(1,44);
% Select the one to be used
featureList = featureList_InfoGain12H;
% The number of samples, used to get the confidence intervals
n = 300;
% The number of times the cross-validation is repeated
runs = 10;
runSVM = false;
runKNNw = false;
neigbourCount = 20;
runBoostedTrees = true;
ROCplotTitle = strcat('BOOSTEDT_TEST_ROC Curves for the cross-validated (k=',num2str(neigbourCount),') KNN with 15 features based on info gain');
ROCplotFilename = strcat('BOOSTEDT_TEST_KNNw_',num2str(neigbourCount),'ROC_15f_infoG_g2_24H');
%% Prepare dataTable to format usable for classification
% Extract predictors and response
% This code processes the data into the right shape for training the
% model.
inputTable = trainingData;
predictorNames = {'Nonlinearenergy', 'Activity', 'Mobility', 'Complexity', 'RMSAmplitude', 'kurtosis', 'skewness', 'meanAM', 'stdAM', 'SkewAM', 'KurtAM', 'BSR', 'delta', 'theta', 'alpha', 'spindle', 'beta', 'total', 'delta_tot', 'theta_tot', 'alpha_tot', 'spindle_tot', 'beta_tot', 'alpha_delta', 'theta_delta', 'spindle_delta', 'beta_delta', 'alpha_theta', 'spindle_theta', 'beta_theta', 'fhtife1', 'fhtife2', 'fhtife3', 'fhtife4', 'sef', 'df', 'svd_ent', 'H_spec', 'SE', 'saen', 'absrenyi', 'absshan', 'perm_entr', 'FD'};
predictors = inputTable(:, predictorNames);
response = inputTable.PatientOutcome;
isCategoricalPredictor = false(1,44);
% Data transformation: Select subset of the features
% This code selects the same subset of features as were used in the app.
includedPredictorNames = predictors.Properties.VariableNames(featureList);
predictors = predictors(:,includedPredictorNames);
isCategoricalPredictor = isCategoricalPredictor(featureList);
%% Classifier dependent
% Train a classifier ON ALL DATA
% This code specifies all the classifier options and trains the classifier.
if runSVM 
    initialClassifier = fitcsvm(...
    predictors, ...
    response, ...
    'KernelFunction', 'linear', ...
    'PolynomialOrder', [], ...
    'KernelScale', 'auto', ...
    'BoxConstraint', 1, ...
    'Standardize', true, ...
    'ClassNames', [0; 1]);
elseif runKNNw
    initialClassifier = fitcknn(...
    predictors, ...
    response, ...
    'Distance', 'Euclidean', ...
    'Exponent', [], ...
    'NumNeighbors', neigbourCount, ...
    'DistanceWeight', 'SquaredInverse', ...
    'Standardize', true, ...
    'ClassNames', categorical({'0'; '1'}));
elseif runBoostedTrees
    % Train a classifier
    % This code specifies all the classifier options and trains the classifier.
    template = templateTree(...
        'MaxNumSplits', 20);
    initialClassifier = fitcensemble(...
        predictors, ...
        response, ...
        'Method', 'AdaBoostM1', ...
        'NumLearningCycles', 30, ...
        'Learners', template, ...
        'LearnRate', 0.1, ...
        'ClassNames', categorical({'0'; '1'}));
else
    error('No classifier selected: set runSVM, runKNNw, or runBoostedTrees to true.');
end
%% Classifier independent: run evaluations
figure % the figure for the ROC curves
% Perform multiple cross-validations to increase the confidence of getting
% TPr = 1.
FprAndT_ResultsGoal1 = zeros(runs,2);
TprAndT_ResultsGoal2 = zeros(runs,2);
posClassPosteriors = zeros(length(response),runs);
for i = 1:runs
    % Perform cross-validation
    partitionedModel = crossval(initialClassifier, 'KFold', 10);
      % (NOT AUTO GEN)
      % Try 1 fold at a time
      %mdlSVMpost1 = fitSVMPosterior(partitionedModels(i).Trained{1});
      % Get the posterior probabilities
      % Here, ScoreMdl is also a partitioned model, but now with a fitted
      % scoretransform function that can be used to calculate the posterior
      % probability of new or test samples.
      if runSVM
          ScoreMdl = fitSVMPosterior(partitionedModel);
          [~,posteriorProbs] = kfoldPredict(ScoreMdl);
      elseif runKNNw
          [~,posteriorProbs] = kfoldPredict(partitionedModel);
      elseif runBoostedTrees
          [~,posteriorProbs] = kfoldPredict(partitionedModel);
      else
          error('You have to choose a classifier type to run at the start');  
      end
      % This function calculates the posterior probabilities for each patient
      % sample, but does it by the classifier that had that sample as test data.
      % The first column of posteriorProbs is P(c=0), the second is P(c=1).
      % Get the ROC curve. c=1 is the positive class, so use column 2 of posteriorProbs:
      posClassPosteriors(:,i) = posteriorProbs(:,2);
      [Xsvm,Ysvm,Tsvm,AUCsvm] = perfcurve(partitionedModel.Y,posClassPosteriors(:,i),1);
      % Actually plot the ROC
      % TODO: also plot the points of the thresholds T used in the ROCs; I don't
      % really know if this is possible, think about it.
      plot(Xsvm,Ysvm)
      hold on
      %plot(Xnb,Ynb)
      % (For goal 1:) Find the first threshold T for which TPr = 1;
      % Then store the FPr and Threshold
      index = find(Ysvm == 1,1);
      FprAndT_ResultsGoal1(i,:) = [Xsvm(index), Tsvm(index)];
      % (For goal 2:) Find the first threshold T for which FPr = 0.05;
      indexes = find(Xsvm <= 0.050); % TODO: figure out why <= 0.05 does not work
      TprAndT_ResultsGoal2(i,:) = [Ysvm(indexes(end)), Tsvm(indexes(end))];    
      fprintf('Finished run %d of %d\n', i, runs); % Show progress
end
%% Get the worst case, mean, and standard deviation values for goals 1 and 2
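% NOTE: getMeanStdWorstCase, classifyWithT, getMeansStds, and getDelta95pCI
% are project-specific helper functions that are not included in this post.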
[goal1ResultsFprAndT,meanT1,worstCaseT1,stdT1] = getMeanStdWorstCase(FprAndT_ResultsGoal1,runs,1);
[goal2ResultsTprAndT,meanT2,worstCaseT2,stdT2] = getMeanStdWorstCase(TprAndT_ResultsGoal2,runs,2);
% Info on goal 1
fprintf('for goal 1 \nOver %d runs: \nThe mean T was %.4f \nThe worst case T was %.4f \nThe stdT was %.4f \n',runs,meanT1,worstCaseT1,stdT1);
% Info on goal 2
fprintf('for goal 2 \nOver %d runs: \nThe mean T was %.4f \nThe worst case T was %.4f \nThe stdT was %.4f \n',runs,meanT2,worstCaseT2,stdT2);
%% End the plotting of the ROC curve
legend('run1','run2','run3','run4','run5','run6','run7','run8','run9','run10')
xlabel('False positive rate'); ylabel('True positive rate');
title(ROCplotTitle)
hold off
% Save the ROC plot to a png image
print(strcat('fig/',ROCplotFilename),'-dpng');
%% Do classification with the worst-case T for goal 1 and the mean T for goal 2
% This obtains all results per classification, 1 classification per run
totalResults1 = zeros(runs,11);
totalResults2 = zeros(runs,11);
for k = 1:runs
    totalResults1(k,:) = classifyWithT(posClassPosteriors(:,k),response,worstCaseT1);
    totalResults2(k,:) = classifyWithT(posClassPosteriors(:,k),response,meanT2);
end
% Obtain the means and standard deviations from all runs
meanStdResults1 = getMeansStds(totalResults1);
meanStdResults2 = getMeansStds(totalResults2);
%% Get the confidence intervals on the scores
% n is defined at the top of the program
% (on the indexes: acc = 1, specif1 = 2, sensi1 = 3, specif2 = 5, sensi2 = 6 )   
% goal1
[acc1, accL1, accH1] = getDelta95pCI(meanStdResults1(1,1),n);
[specif1, specif1L,specif1H] = getDelta95pCI(meanStdResults1(1,2),n);
[sensi1, sensi1L,sensi1H] = getDelta95pCI(meanStdResults1(1,3),n);
acc1S = meanStdResults1(2,1); specif1S = meanStdResults1(2,2);
sensi1S = meanStdResults1(2,3);
% goal2
[acc2, accL2, accH2] = getDelta95pCI(meanStdResults2(1,1),n);
[specif2, specif2L,specif2H] = getDelta95pCI(meanStdResults2(1,5),n);
[sensi2, sensi2L,sensi2H] = getDelta95pCI(meanStdResults2(1,6),n);
acc2S = meanStdResults2(2,1); specif2S = meanStdResults2(2,5);
sensi2S = meanStdResults2(2,6);
fprintf(strcat('Here are the means with their confidence intervals, and at the end the std.\n', ...
    'acc1\t %.4f\t%.4f\t%.4f\t%.4f\n', ...
    'spec1\t %.4f\t%.4f\t%.4f\t%.4f\n', ...
    'sens1\t %.4f\t%.4f\t%.4f\t%.4f\n', ...
    'acc2\t %.4f\t%.4f\t%.4f\t%.4f\n', ...
    'spec2\t %.4f\t%.4f\t%.4f\t%.4f\n', ...
    'sens2\t %.4f\t%.4f\t%.4f\t%.4f\n'),acc1,accL1,accH1,acc1S,...
    specif1,specif1L,specif1H,specif1S,...
    sensi1,sensi1L,sensi1H,sensi1S,  acc2,accL2,accH2,acc2S, ...
    specif2,specif2L,specif2H,specif2S,...
    sensi2,sensi2L,sensi2H,sensi2S);
fprintf(strcat('The filename of the ROC is: \n',ROCplotFilename,'\n'));
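Once you have settled on a threshold T, applying it to new data comes down to the following (a sketch; finalMdl is assumed to be a model trained on all data, with the positive-class posterior in the second score column):

% Override the default operating point by thresholding the scores yourself
[~, posteriors] = predict(finalMdl, newData);
predictedPositive = posteriors(:,2) >= T;

This is exactly the knob the original question asks about: moving the threshold moves the operating point along the ROC curve.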

Benjamin Hoenes on 20 Apr 2018
Edited: Benjamin Hoenes on 20 Apr 2018

This has been bugging me as well, but I think I have figured it out:

First, it is important to know that

[label,score,cost] = predict(Mdl,X)

does not simply spit out the classification of each observation i in X. It also gives "score" and "cost". Score, as defined in the documentation, is the probability of observation i being a member of class j (so if you have 2 classes, say "on" and "off", there will be one probability for an observation to be "on" and another for it to be "off").
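For instance, a quick way to check which score column belongs to which class (a sketch; Mdl is any trained classification model):

Mdl.ClassNames            % the columns of score follow this order
[label,score] = predict(Mdl,X);
pOn = score(:,2);         % posterior of the second class in ClassNames, here "on"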

In short, if you want to alter the sensitivity or specificity, use predict and then threshold as you see fit, e.g.:

% score(i,1) is the posterior for the first class in Mdl.ClassNames
if score(i,1) > 0.4
    observation(i) = "on";
end

https://www.mathworks.com/help/stats/compactclassificationdiscriminant.predict.html

You can determine the optimal threshold values using

[X,Y,T,AUC,OPTROCPT,SUBY,SUBYNAMES] = perfcurve(labels,scores,posclass)

https://www.mathworks.com/help/stats/perfcurve.html#bunsogv-OPTROCPT
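For example, to recover the threshold that produces the optimal operating point (a sketch; labels, scores, and posclass as in the call above):

[X,Y,T,AUC,OPTROCPT] = perfcurve(labels,scores,posclass);
% Each ROC point (X(i),Y(i)) was generated by the threshold T(i),
% so look up the point that matches the optimal one in OPTROCPT:
idx = find(X == OPTROCPT(1) & Y == OPTROCPT(2), 1);
optimalT = T(idx);   % classify as posclass when score >= optimalT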

I hope that helps!
