cvpartition

Partition data for cross-validation

Description

cvpartition defines a random partition on a data set. Use this partition to define training and test sets for validating a statistical model using cross-validation. Use training to extract the training indices and test to extract the test indices for cross-validation. Use repartition to define a new random partition of the same type as a given cvpartition object.

Creation

Description

c = cvpartition(n,"KFold",k) returns a cvpartition object c that defines a random nonstratified partition for k-fold cross-validation on n observations. The partition randomly divides the observations into k disjoint subsamples, or folds, each of which has approximately the same number of observations.
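The following minimal sketch illustrates the k-fold syntax. It assumes a hypothetical data set of 100 observations split into 5 folds. For each fold i, test(c,i) returns the logical test indices and training(c,i) returns the training indices. The loop counts how many test sets contain each observation; for a k-fold partition, that count should be exactly one.

```matlab
rng("default")                      % For reproducibility
n = 100;                            % Hypothetical number of observations
c = cvpartition(n,"KFold",5);

timesTested = zeros(n,1);           % Number of test sets that contain each observation
for i = 1:c.NumTestSets
    idxTest = test(c,i);            % Logical test indices for fold i
    idxTrain = training(c,i);       % Logical training indices for fold i
    assert(~any(idxTest & idxTrain))   % Training and test sets of a fold do not overlap
    timesTested(idxTest) = timesTested(idxTest) + 1;
end
c.TestSize                          % Size of each test set (fold)
```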

c = cvpartition(n,"Holdout",p) creates a random nonstratified partition for holdout validation on n observations. This partition divides the observations into a training set and a test, or holdout, set.

c = cvpartition(group,"KFold",k) creates a random partition for stratified k-fold cross-validation. Each subsample, or fold, has approximately the same number of observations and contains approximately the same class proportions as in group.

When you specify group as the first input argument, cvpartition discards rows of observations corresponding to missing values in group.
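To see the effect of missing group values, the sketch below uses a small hypothetical grouping vector with a NaN in row 5. Under the behavior described above, that row should appear in neither a training set nor a test set, while the NumObservations property still counts it.

```matlab
group = [1;1;2;2;NaN;1;2;1;2;2];     % Hypothetical grouping variable, missing value in row 5
rng("default")                       % For reproducibility
c = cvpartition(group,"KFold",3);

usedForTest = false(size(group));
usedForTraining = false(size(group));
for i = 1:c.NumTestSets
    usedForTest(test(c,i)) = true;          % Mark observations in test set i
    usedForTraining(training(c,i)) = true;  % Mark observations in training set i
end
[usedForTest(5) usedForTraining(5)]  % Row with the missing value is never used
c.NumObservations                    % Still counts the row with the missing value
```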

c = cvpartition(group,"KFold",k,"Stratify",stratifyOption) returns a cvpartition object c that defines a random partition for k-fold cross-validation. If you specify "Stratify",false, then cvpartition ignores the class information in group and creates a nonstratified random partition. Otherwise, the function implements stratification by default.

c = cvpartition(group,"Holdout",p) randomly partitions observations into a training set and a test, or holdout, set with stratification, using the class information in group. Both the training and test sets have approximately the same class proportions as in group.

c = cvpartition(group,"Holdout",p,"Stratify",stratifyOption) returns an object c that defines a random partition into a training set and a test, or holdout, set. If you specify "Stratify",false, then cvpartition creates a nonstratified random partition. Otherwise, the function implements stratification by default.
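This sketch compares the stratified (default) and nonstratified holdout syntaxes. It assumes a hypothetical imbalanced grouping variable with 20 observations of class 1 and 80 of class 2. With stratification, the test set should contain roughly 5 and 20 observations of the two classes. Without stratification, the class counts depend on the random draw.

```matlab
group = [ones(20,1); 2*ones(80,1)];   % Hypothetical classes 1 and 2 in the ratio 1:4
rng("default")                        % For reproducibility
cStrat = cvpartition(group,"Holdout",0.25);                     % Stratified by default
cNonStrat = cvpartition(group,"Holdout",0.25,"Stratify",false); % Nonstratified

% Number of test-set observations in each class (2-by-1 output even if a class is absent)
countsStrat = accumarray(group(test(cStrat)),1,[2 1])
countsNonStrat = accumarray(group(test(cNonStrat)),1,[2 1])
```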

c = cvpartition(n,"Leaveout") creates a random partition for leave-one-out cross-validation on n observations. Leave-one-out is a special case of "KFold" in which the number of folds equals the number of observations.

c = cvpartition(n,"Resubstitution") creates an object c that does not partition the data. Both the training set and the test set contain all of the original n observations.
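The two syntaxes above can be checked with a short sketch on a hypothetical data set of 6 observations. In the leave-one-out partition, every test set has exactly one observation and every observation is tested once. In the resubstitution partition, both index vectors select all observations.

```matlab
n = 6;                                % Hypothetical number of observations
cLoo = cvpartition(n,"Leaveout");
timesTested = zeros(n,1);             % Number of test sets that contain each observation
for i = 1:cLoo.NumTestSets
    idxTest = test(cLoo,i);
    timesTested(idxTest) = timesTested(idxTest) + 1;
end

cResub = cvpartition(n,"Resubstitution");
allTrain = all(training(cResub))      % Every observation is in the training set
allTest = all(test(cResub))           % Every observation is also in the test set
```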

c = cvpartition("CustomPartition",testSets) creates a cvpartition object c that partitions the data based on the test sets indicated in testSets. (since R2023b)

Input Arguments

n

Number of observations in the sample data, specified as a positive integer scalar.

Example: 100

Data Types: single | double

k

Number of folds in the partition, specified as a positive integer scalar. k must be smaller than the total number of observations.

Example: 5

Data Types: single | double

p

Fraction or number of observations in the test set used for holdout validation, specified as a scalar in the range (0,1) or an integer scalar in the range [1,n), where n is the total number of observations.

  • If p is a scalar in the range (0,1), then cvpartition randomly selects approximately p*n observations for the test set.

  • If p is an integer scalar in the range [1,n), then cvpartition randomly selects p observations for the test set.
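The sketch below contrasts the two forms of p on a hypothetical data set of 100 observations. A fraction gives approximately p*n test observations, and an integer gives exactly p.

```matlab
n = 100;                                  % Hypothetical number of observations
rng("default")                            % For reproducibility
cFraction = cvpartition(n,"Holdout",0.2); % p in (0,1): about 0.2*n test observations
cCount = cvpartition(n,"Holdout",30);     % Integer p: exactly 30 test observations
[cFraction.TestSize cCount.TestSize]
```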

Example: 0.2

Example: 50

Data Types: single | double

group

Grouping variable for stratification, specified as a numeric or logical vector, a categorical, character, or string array, or a cell array of character vectors indicating the class of each observation. cvpartition creates a partition from the observations in group.

Data Types: single | double | logical | categorical | char | string | cell

stratifyOption

Indicator for stratification, specified as true or false.

  • If the first input argument to cvpartition is group, then cvpartition implements stratification by default ("Stratify",true). For a nonstratified random partition, specify "Stratify",false.

  • If the first input argument to cvpartition is n, then cvpartition always creates a nonstratified random partition ("Stratify",false). In this case, you cannot specify "Stratify",true.

Data Types: logical
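The following hedged sketch illustrates the restriction in the second bullet. The code assumes that requesting stratification with a numeric n as the first input raises an error, which the sketch catches and reports. An explicit "Stratify",false is assumed to remain valid in that case.

```matlab
% Assumption: "Stratify",true with n as the first input is rejected with an error
try
    cvpartition(100,"KFold",5,"Stratify",true);
    stratifyRejected = false;
catch ME
    stratifyRejected = true;
    disp(ME.message)
end

% Nonstratified partitions remain available when the first input is n
c = cvpartition(100,"KFold",5,"Stratify",false);
```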

testSets

Since R2023b

Custom test sets, specified as a positive integer vector, logical vector, or logical matrix.

  • For holdout validation, specify the test set observations by using a logical vector. A value of 1 (true) indicates that the corresponding observation is in the test set, and a value of 0 (false) indicates that the corresponding observation is in the training set.

  • For k-fold cross-validation, specify the test set observations by using an integer vector (with values in the range [1,k]) or a logical matrix with k columns.

    • Integer vector — A value of j indicates that the corresponding observation is in test set j.

    • Logical matrix — The value in row i and column j indicates whether observation i is in test set j.

    Each of the k test sets must contain at least one observation.

  • For leave-one-out cross-validation, specify the test set observations by using an integer vector (with values in the range [1,n]) or an n-by-n logical matrix, where n is the number of observations in the data.

    • Integer vector — A value of j indicates that the corresponding observation is in test set j.

    • Logical matrix — The value in row i and column j indicates whether observation i is in test set j.

Example: "CustomPartition",[true false true false false] indicates a holdout validation scheme, with the first and third observations in the test set.

Example: "CustomPartition",[1 2 2 1 3 3 1 2 3 2] indicates a 3-fold cross-validation scheme, with the first, fourth, and seventh observations in the first test set.

Data Types: single | double | logical
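This sketch assumes R2023b or later. It builds the 3-fold example above two ways: as an integer vector and as the equivalent 10-by-3 logical matrix. It then checks that the two custom partitions define the same test sets. It also creates the holdout example from a logical vector. The variable names are illustrative.

```matlab
foldIdx = [1 2 2 1 3 3 1 2 3 2]';        % Integer vector: test set of each observation
cInt = cvpartition("CustomPartition",foldIdx);

foldMatrix = (foldIdx == 1:3);           % 10-by-3 logical matrix, column j = test set j
cMat = cvpartition("CustomPartition",foldMatrix);

sameTestSets = true;
for j = 1:cInt.NumTestSets
    sameTestSets = sameTestSets && isequal(test(cInt,j),test(cMat,j));
end
sameTestSets

% Logical vector: holdout scheme, true marks a test-set observation
cHoldout = cvpartition("CustomPartition",[true false true false false]);
cHoldout.TestSize
```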

Properties

IsCustom

Since R2023b

This property is read-only.

Indicator of a custom partition, specified as a logical scalar. The value is 1 (true) when the object was created using custom partitioning. The value is 0 (false) otherwise.

Data Types: logical

NumObservations

This property is read-only.

Number of observations, including observations with missing group values, specified as a positive integer scalar.

Data Types: double

NumTestSets

This property is read-only.

Total number of test sets in the partition, specified as the number of folds when the partition type is 'kfold' or 'leaveout', and 1 when the partition type is 'holdout' or 'resubstitution'.

Data Types: double

TestSize

This property is read-only.

Size of each test set, specified as a positive integer vector when the partition type is 'kfold' or 'leaveout', and a positive integer scalar when the partition type is 'holdout' or 'resubstitution'.

Data Types: double

TrainSize

This property is read-only.

Size of each training set, specified as a positive integer vector when the partition type is 'kfold' or 'leaveout', and a positive integer scalar when the partition type is 'holdout' or 'resubstitution'.

Data Types: double

Type

This property is read-only.

Type of validation partition, specified as 'kfold', 'holdout', 'leaveout', or 'resubstitution'.
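To see how these properties relate, the following sketch creates a k-fold partition of a hypothetical 103 observations. Because 103 is not divisible by 5, the folds cannot all be the same size. TestSize and TrainSize are vectors with one entry per fold, and for each fold the two sizes add up to NumObservations.

```matlab
rng("default")                          % For reproducibility
c = cvpartition(103,"KFold",5);         % Hypothetical 103 observations, not divisible by 5
c.Type
c.TestSize                              % One entry per test set (fold)
c.TrainSize
c.TestSize + c.TrainSize                % Each entry equals c.NumObservations
```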

Object Functions

  • repartition: Repartition data for cross-validation
  • test: Test indices for cross-validation
  • training: Training indices for cross-validation
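Here is a minimal sketch of repartition, assuming a hypothetical 50-observation k-fold partition. The new object keeps the partition type and fold sizes. However, observations are reassigned to folds at random, so the test sets generally differ.

```matlab
rng("default")                          % For reproducibility
c = cvpartition(50,"KFold",5);          % Hypothetical 50 observations
cNew = repartition(c);                  % New random partition of the same type
sameFirstFold = isequal(test(c,1),test(cNew,1))   % Usually false
```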

Examples

Use the cross-validation misclassification error to estimate how a model will perform on new data.

Load the ionosphere data set. Create a table containing the predictor data X and the response variable Y.

load ionosphere
tbl = array2table(X);
tbl.Y = Y;

Use a random nonstratified partition hpartition to split the data into training data (tblTrain) and a reserved data set (tblNew). Reserve approximately 30 percent of the data.

rng('default') % For reproducibility
n = length(tbl.Y);
hpartition = cvpartition(n,'Holdout',0.3); % Nonstratified partition
idxTrain = training(hpartition);
tblTrain = tbl(idxTrain,:);
idxNew = test(hpartition);
tblNew = tbl(idxNew,:);

Train a support vector machine (SVM) classification model using the training data tblTrain. Calculate the misclassification error and the classification accuracy on the training data.

Mdl = fitcsvm(tblTrain,'Y');
trainError = resubLoss(Mdl)
trainError = 0.0569
trainAccuracy = 1-trainError
trainAccuracy = 0.9431

Typically, the misclassification error on the training data is not a good estimate of how a model will perform on new data because it can underestimate the misclassification rate on new data. A better estimate is the cross-validation error.

Create a partitioned model cvMdl. Compute the 10-fold cross-validation misclassification error and classification accuracy. By default, crossval ensures that the class proportions in each fold remain approximately the same as the class proportions in the response variable tblTrain.Y.

cvMdl = crossval(Mdl); % Performs stratified 10-fold cross-validation
cvtrainError = kfoldLoss(cvMdl)
cvtrainError = 0.1220
cvtrainAccuracy = 1-cvtrainError
cvtrainAccuracy = 0.8780

Notice that the cross-validation error cvtrainError is greater than the resubstitution error trainError.

Classify the new data in tblNew using the trained SVM model. Compare the classification accuracy on the new data to the accuracy estimates trainAccuracy and cvtrainAccuracy.

newError = loss(Mdl,tblNew,'Y');
newAccuracy = 1-newError
newAccuracy = 0.8700

The cross-validation error gives a better estimate of the model performance on new data than the resubstitution error.
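crossval hides the fold loop. For illustration only, the sketch below writes that loop out explicitly. It uses training and test on the ionosphere table, fits one SVM per fold, and averages the per-fold losses. Note that it cross-validates the full table tbl rather than tblTrain, and averaging fold losses is just one reasonable summary. For both reasons, the result is not expected to equal cvtrainError.

```matlab
load ionosphere
tbl = array2table(X);
tbl.Y = Y;

rng("default")                                  % For reproducibility
cvp = cvpartition(tbl.Y,"KFold",10);            % Stratified 10-fold partition of the full table
foldError = zeros(cvp.NumTestSets,1);
for i = 1:cvp.NumTestSets
    foldMdl = fitcsvm(tbl(training(cvp,i),:),"Y");        % Train on the other 9 folds
    foldError(i) = loss(foldMdl,tbl(test(cvp,i),:),"Y");   % Evaluate on the held-out fold
end
manualCVError = mean(foldError)                 % Average misclassification rate over folds
```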

Use the same stratified partition for 5-fold cross-validation to compute the misclassification rates of two models.

Load the fisheriris data set. The matrix meas contains flower measurements for 150 different flowers. The variable species lists the species for each flower.

load fisheriris

Create a random partition for stratified 5-fold cross-validation. The training and test sets have approximately the same proportions of flower species as species.

rng('default') % For reproducibility
c = cvpartition(species,'KFold',5);

Create a partitioned discriminant analysis model and a partitioned classification tree model by using c.

discrCVModel = fitcdiscr(meas,species,'CVPartition',c);
treeCVModel = fitctree(meas,species,'CVPartition',c);

Compute the misclassification rates of the two partitioned models.

discrRate = kfoldLoss(discrCVModel)
discrRate = 0.0200
treeRate = kfoldLoss(treeCVModel)
treeRate = 0.0333

The discriminant analysis model has a smaller cross-validation misclassification rate.

Observe the test set (fold) class proportions in a 5-fold nonstratified partition of the fisheriris data. The class proportions differ across the folds.

Load the fisheriris data set. The species variable contains the species name (class) for each flower (observation). Convert species to a categorical variable.

load fisheriris
species = categorical(species);

Find the number of observations in each class. Notice that the three classes occur in equal proportion.

C = categories(species) % Class names
C = 3x1 cell
    {'setosa'    }
    {'versicolor'}
    {'virginica' }

numClasses = size(C,1);
n = countcats(species) % Number of observations in each class
n = 3×1

    50
    50
    50

Create a random nonstratified 5-fold partition.

rng('default') % For reproducibility
cv = cvpartition(species,'KFold',5,'Stratify',false) 
cv = 
K-fold cross validation partition
   NumObservations: 150
       NumTestSets: 5
         TrainSize: 120  120  120  120  120
          TestSize: 30  30  30  30  30
          IsCustom: 0

Show that the three classes do not occur in equal proportion in each of the five test sets, or folds. Use a for-loop to update the nTestData matrix so that each entry nTestData(i,j) corresponds to the number of observations in test set i and class C(j). Create a bar chart from the data in nTestData.

numFolds = cv.NumTestSets;
nTestData = zeros(numFolds,numClasses);
for i = 1:numFolds
    testClasses = species(cv.test(i));
    nCounts = countcats(testClasses); % Number of test set observations in each class
    nTestData(i,:) = nCounts';
end

bar(nTestData)
xlabel('Test Set (Fold)')
ylabel('Number of Observations')
title('Nonstratified Partition')
legend(C)

Figure: Bar chart titled "Nonstratified Partition". It shows the number of setosa, versicolor, and virginica observations in each test set. x-axis: Test Set (Fold); y-axis: Number of Observations.

Notice that the class proportions vary in some of the test sets. For example, the first test set contains 8 setosa, 13 versicolor, and 9 virginica flowers, rather than 10 flowers per species. Because cv is a random nonstratified partition of the fisheriris data, the class proportions in each test set (fold) are not guaranteed to be equal to the class proportions in species. That is, the classes do not always occur equally in each test set, as they do in species.
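For comparison, this sketch repeats the fold count with the default stratified partition of species. Each class has 50 observations and there are 5 folds, so each test set should contain close to 10 flowers of each species.

```matlab
load fisheriris
species = categorical(species);
numClasses = numel(categories(species));

rng("default")                                   % For reproducibility
cvStrat = cvpartition(species,"KFold",5);        % Stratified by default
nTestStrat = zeros(cvStrat.NumTestSets,numClasses);
for i = 1:cvStrat.NumTestSets
    nTestStrat(i,:) = countcats(species(test(cvStrat,i)))';  % Class counts in fold i
end
nTestStrat
```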

Create a nonstratified holdout partition and a stratified holdout partition for a tall array. For the two holdout sets, compare the number of observations in each class.

When you perform calculations on tall arrays, MATLAB® uses either a parallel pool (the default if you have Parallel Computing Toolbox™) or the local MATLAB session. To run the example using the local MATLAB session when you have Parallel Computing Toolbox, change the global execution environment by using the mapreducer function.

mapreducer(0)

Create a numeric vector of two classes, where class 1 and class 2 occur in the ratio 1:10.

group = [ones(20,1);2*ones(200,1)]
group = 220×1

     1
     1
     1
     1
     1
     1
     1
     1
     1
     1
      ⋮

Create a tall array from group.

tgroup = tall(group)
tgroup =

  220x1 tall double column vector

     1
     1
     1
     1
     1
     1
     1
     1
     :
     :

Holdout is the only cvpartition option that is supported for tall arrays. Create a random nonstratified holdout partition.

CV0 = cvpartition(tgroup,'Holdout',1/4,'Stratify',false)  
CV0 = 
Hold-out cross validation partition
   NumObservations: [1x1 tall]
       NumTestSets: 1
         TrainSize: [1x1 tall]
          TestSize: [1x1 tall]
          IsCustom: 0

Return the result of CV0.test to memory by using the gather function.

testIdx0 = gather(CV0.test);
Evaluating tall expression using the Local MATLAB Session:
- Pass 1 of 1: Completed in 0.42 sec
Evaluation completed in 0.6 sec

Find the number of times each class occurs in the test, or holdout, set.

accumarray(group(testIdx0),1) % Number of observations per class in the holdout set
ans = 2×1

     5
    51

cvpartition produces randomness in the results, so your number of observations in each class can vary from those shown.

Because CV0 is a nonstratified partition, class 1 observations and class 2 observations in the holdout set are not guaranteed to occur in the same ratio as in tgroup. However, because of the inherent randomness in cvpartition, you can sometimes obtain a holdout set in which the classes occur in the same ratio as in tgroup, even though you specify 'Stratify',false. Because the training set is the complement of the holdout set, excluding any NaN or missing observations, you can obtain a similar result for the training set.

Return the result of CV0.training to memory.

trainIdx0 = gather(CV0.training);
Evaluating tall expression using the Local MATLAB Session:
- Pass 1 of 1: Completed in 0.14 sec
Evaluation completed in 0.19 sec

Find the number of times each class occurs in the training set.

accumarray(group(trainIdx0),1) % Number of observations per class in the training set
ans = 2×1

    15
   149

The classes in the nonstratified training set are not guaranteed to occur in the same ratio as in tgroup.

Create a random stratified holdout partition.

CV1 = cvpartition(tgroup,'Holdout',1/4)  
CV1 = 
Hold-out cross validation partition
   NumObservations: [1x1 tall]
       NumTestSets: 1
         TrainSize: [1x1 tall]
          TestSize: [1x1 tall]
          IsCustom: 0

Return the result of CV1.test to memory.

testIdx1 = gather(CV1.test);
Evaluating tall expression using the Local MATLAB Session:
- Pass 1 of 1: Completed in 0.093 sec
Evaluation completed in 0.12 sec

Find the number of times each class occurs in the test, or holdout, set.

accumarray(group(testIdx1),1) % Number of observations per class in the holdout set
ans = 2×1

     5
    51

Because CV1 is a stratified partition, the class ratio in the holdout set is approximately the same as the class ratio in tgroup (1:10). Unlike with CV0, this agreement is not a matter of chance.

Create a random partition of data for leave-one-out cross-validation. Compute and compare training set means. A repetition with a significantly different mean suggests the presence of an influential observation.

Create a data set X that contains one value that is much greater than the others.

X = [1 2 3 4 5 6 7 8 9 20]';

Create a cvpartition object that has 10 observations and 10 repetitions of training and test data. For each repetition, cvpartition selects one observation to remove from the training set and reserve for the test set.

c = cvpartition(10,'Leaveout')
c = 
Leave-one-out cross validation partition
   NumObservations: 10
       NumTestSets: 10
         TrainSize: 9  9  9  9  9  9  9  9  9  9
          TestSize: 1  1  1  1  1  1  1  1  1  1
          IsCustom: 0

Apply the leave-one-out partition to X, and take the mean of the training observations for each repetition by using crossval.

values = crossval(@(Xtrain,Xtest)mean(Xtrain),X,'Partition',c)
values = 10×1

    6.5556
    6.4444
    7.0000
    6.3333
    6.6667
    7.1111
    6.8889
    6.7778
    6.2222
    5.0000

View the distribution of the training set means using a box chart (or box plot). The plot displays one outlier.

boxchart(values)

Figure: Box chart of the training set means in values.

Find the repetition corresponding to the outlier value. For that repetition, find the observation in the test set.

[~,repetitionIdx] = min(values)
repetitionIdx = 10
observationIdx = test(c,repetitionIdx);
influentialObservation = X(observationIdx)
influentialObservation = 20

Training sets that contain the observation have substantially different means from the mean of the training set without the observation. This significant change in mean suggests that the value of 20 in X is an influential observation.
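The same leave-one-out computation can be written without crossval by indexing X with training(c,i) directly. This sketch recomputes the training set means. It then identifies the observation left out of the repetition with the smallest mean.

```matlab
X = [1 2 3 4 5 6 7 8 9 20]';
c = cvpartition(10,"Leaveout");
trainMeans = zeros(c.NumTestSets,1);
for i = 1:c.NumTestSets
    trainMeans(i) = mean(X(training(c,i)));  % Mean of the 9 training observations
end
[~,repIdx] = min(trainMeans);               % Repetition with the smallest training mean
leftOut = X(test(c,repIdx))                 % Observation held out in that repetition
```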

Create a cross-validated regression tree by specifying a custom 4-fold cross-validation partition.

Load the carbig data set. Create a table Tbl containing the response variable MPG and the predictor variables Acceleration, Cylinders, and so on.

load carbig
Tbl = table(Acceleration,Cylinders,Displacement, ...
    Horsepower,Model_Year,Weight,Origin,MPG);

Remove observations with missing values. Check the size of the table data after the removal of the observations with missing values.

Tbl = rmmissing(Tbl);
dimensions = size(Tbl)
dimensions = 1×2

   392     8

The resulting table contains 392 observations, which divide evenly into four folds of 98 observations each (392/4 = 98).

Create a custom 4-fold cross-validation partition of the Tbl data. Place the first 98 observations in the first test set, the next 98 observations in the second test set, and so on.

testSet = ones(98,1);
testIndices = [testSet; 2*testSet; ...
    3*testSet; 4*testSet];
c = cvpartition("CustomPartition",testIndices)
c = 
K-fold cross validation partition
   NumObservations: 392
       NumTestSets: 4
         TrainSize: 294  294  294  294
          TestSize: 98  98  98  98
          IsCustom: 1
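When the folds are contiguous blocks of equal size, repelem can build the same test-set vector more compactly. This sketch assumes R2023b or later for the "CustomPartition" syntax and checks that the two constructions agree.

```matlab
testSet = ones(98,1);
testIndicesBlocks = [testSet; 2*testSet; 3*testSet; 4*testSet];  % Construction used above
testIndicesRepelem = repelem((1:4)',98);     % Each fold label repeated 98 times
isequal(testIndicesBlocks,testIndicesRepelem)
c = cvpartition("CustomPartition",testIndicesRepelem);
c.TestSize
```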

Train a cross-validated regression tree using the custom partition c. To assess the model performance, compute the cross-validation mean squared error (MSE).

cvMdl = fitrtree(Tbl,"MPG","CVPartition",c);
cvMSE = kfoldLoss(cvMdl)
cvMSE = 21.2223

Tips

  • If you specify group as the first input argument to cvpartition, then the function discards rows of observations corresponding to missing values in group.

  • If you specify group as the first input argument to cvpartition, then the function implements stratification by default. You can specify "Stratify",false to create a nonstratified random partition.

  • You can specify "Stratify",true only when the first input argument to cvpartition is group.

Version History

Introduced in R2008a
