How does crossval (for k-fold CV) work in MATLAB after training a classifier?

To my knowledge, k-fold CV is a technique for model selection where the data is first divided into k folds, with the data in each fold stratified by class. Now, consider the following code:
trainedClassifier = fitcnb(X, Y);
partitionedModel = crossval(trainedClassifier, 'KFold', 10);
accuracy = 1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError');
The above code first trains a classifier on the data in matrix X using the class labels in vector Y. The trainedClassifier is then passed to the function crossval(). My doubt is very simple. Does this line of code
partitionedModel = crossval(trainedClassifier, 'KFold', 10);
divide the matrix X into ten folds, train on 9 folds, test on the remaining fold, and repeat this 10 times with each fold serving once as the test set? Or does it simply take the trainedClassifier that was already trained (in the previous line) on the whole matrix X and test it on each fold? I can see that fitcnb is called only once, so does crossval() retrain internally? If it doesn't, then the training is being done on the whole data instead of on 9 folds in each iteration, as cross-validation requires.
Fellow members of the community, I would be highly obliged if this doubt of mine could be cleared. Thank you in anticipation.

3 Comments

I have the same question on this matter. Anyone got any answers, please?
I too have the same question. Did anyone find the answer?
I have the same question. Do you have any answer now?


Answers (3)

The answer is that it divides the dataset into 10 folds and trains the model 10 times on 9 folds each time, using the remaining fold as the test set. The only information taken from 'trainedClassifier' are the hyperparameter values, which are used in each of the 10 trainings. 'fitcnb' is not called 10 times, 'ClassificationNaiveBayes.fit' is.
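To make this concrete, here is an illustrative sketch (not MATLAB's internal implementation) of what crossval(trainedClassifier, 'KFold', 10) does conceptually, written out by hand with cvpartition. The variable names are made up for this example, and the final average is only approximately what kfoldLoss computes, since kfoldLoss weights folds by their size:

```matlab
rng(1);                                  % for a reproducible partition
cvp = cvpartition(Y, 'KFold', 10);       % stratified 10-fold split of the labels
foldErr = zeros(cvp.NumTestSets, 1);
for k = 1:cvp.NumTestSets
    trIdx = training(cvp, k);            % logical index: the 9 training folds
    teIdx = test(cvp, k);                % logical index: the 1 held-out fold
    % A fresh model is fit on the 9 folds each time -- the model that was
    % trained on the full data is NOT reused here, only its settings are.
    mdl = fitcnb(X(trIdx,:), Y(trIdx));
    foldErr(k) = loss(mdl, X(teIdx,:), Y(teIdx), 'LossFun', 'classiferror');
end
estimatedAccuracy = 1 - mean(foldErr);   % comparable to 1 - kfoldLoss(...)
```

The loop makes explicit why fitcnb appearing once in the posted code is not a problem: crossval performs the 10 retrainings internally.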

11 Comments

Do you mean this is a complete k-fold cross-validation?
Yes, this is standard k-fold cross validation. I assume that‘s what you mean by “complete”.
"The only information taken from 'trainedClassifier' are the hyperparameter values" -- does this mean that all 10 trainings use the same hyperparameters, namely the ones obtained from the full data?
No no, hyperparameters are not obtained from the data at all. Hyperparameters are the settings that you pass to fitcnb, such as 'DistributionNames' or 'Kernel'.
I will try to explain what is happening in the originally posted code:
trainedClassifier = fitcnb(X, Y);
This line fits a single Naive Bayes model to the full dataset {X,Y} using default hyperparameters (settings), such as 'normal' distributions. The result is a ClassificationNaiveBayes object.
partitionedModel = crossval(trainedClassifier, 'KFold', 10);
This line splits the data {X,Y} into 10 folds. It then trains 10 models, each trained on 9 folds. For each of these models, it uses the settings from 'trainedClassifier', namely 'normal' distributions. The result is a ClassificationPartitionedModel, which contains the 10 models that were just trained.
accuracy = 1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError');
This line runs each of the 10 models inside 'partitionedModel' on its own held-out test set, and computes the loss for each model on that test set. These 10 individual loss values are then combined (weighted by fold size) to obtain the full 'kfold' loss. This loss is the classification error rate over the full dataset, because the 10 test folds together make up the full dataset. This is the "out of sample" error rate on the full dataset. 'accuracy' is then 1 minus the error rate.
Finally, you would then use 'trainedClassifier' to make predictions on new data. 'accuracy' is now an estimate of the out-of-sample accuracy of 'trainedClassifier'. At this point, 'partitionedModel' can be discarded. Its only purpose was to provide an estimate of the out-of-sample accuracy of trainedClassifier. In fact, you cannot use 'partitionedModel' for prediction. It has no 'predict' method.
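A short sketch of this division of labor, assuming X, Y, trainedClassifier, and partitionedModel from the code above exist (Xnew is a hypothetical matrix of new observations, and the error computation assumes Y is a cell array of character vectors):

```matlab
% kfoldPredict returns, for every row of X, the prediction made by the one
% fold model that did NOT see that row during training.
oofLabels = kfoldPredict(partitionedModel);
oofError  = mean(~strcmp(oofLabels, Y));   % should match kfoldLoss(partitionedModel)

% New data is scored with the full-data model, not with partitionedModel,
% which exists only to estimate out-of-sample accuracy:
newLabels = predict(trainedClassifier, Xnew);
```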
Okay, I understood.
So, let me translate the following code as described:
gprMdl2 = fitrgp(x,y,'KernelFunction','squaredexponential',...
'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',...
struct('AcquisitionFunctionName','expected-improvement-plus'));
This code contains an option to optimize the hyperparameter automatically. After that,
partitionedModel = crossval(gprMdl2, 'KFold', 10);
This line splits the data {x, y} into 10 folds and trains 10 models, each on 9 folds. For each of these models, does it use the settings from 'gprMdl2', including the optimized hyperparameters? Does each of the 10 models get its own optimized hyperparameters? If so, since the training sets differ slightly, the 10 models would have slightly different optimized hyperparameters; but that would mean fitrgp's optimization was effectively run 10 times.
There is no hyperparameter optimization in your line
gprMdl2 = fitrgp(x,y,'KernelFunction','squaredexponential',...
'KernelParameters',kparams0,'Sigma',sigma0);
gprMdl2 = fitrgp(x,y,'KernelFunction','squaredexponential',...
'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',...
struct('AcquisitionFunctionName','expected-improvement-plus'));
This line runs a Bayesian optimization to find hyperparameters (settings) for fitrgp that minimize the cross-validation loss. After this line, the optimization is over and will not be run again. gprMdl2 contains the "best" hyperparameters that were found.
partitionedModel = crossval(gprMdl2, 'KFold', 10);
This line then does what I said earlier. It splits the dataset, trains 10 models, and stores them. No optimization is done here.
Yes, I understand now. Thank you.
I've run it directly with the above code, and I've verified that the optimized hyperparameter Sigma has the same value in all 10 partitioned models. However, the other hyperparameters, the kernel length scales, had different values in each of the 10 partitioned models. What happened?
The kernel length scales are fit as part of the normal fitting process, so each fold's model re-estimates them from its own training data.
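One way to see this for yourself, assuming partitionedModel was built from gprMdl2 as above: the 10 fold models are stored in the Trained property of the ClassificationPartitionedModel/RegressionPartitionedModel, and their fitted kernel parameters can be inspected directly.

```matlab
% Print the fitted kernel parameters of each fold model. For the
% 'squaredexponential' kernel these are [length scale; signal std];
% the length scales are re-estimated per fold, so they differ.
for k = 1:numel(partitionedModel.Trained)
    mdl = partitionedModel.Trained{k};              % compact GP model for fold k
    disp(mdl.KernelInformation.KernelParameters');
end
```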
Okay, now I fully understand. Thank you.


Asked: on 7 Mar 2016
Commented: on 1 Feb 2019
