Stratified K-Fold Cross Validation is a technique for evaluating a model. It is particularly useful for classification problems in which the class labels are not evenly distributed, i.e., the data is imbalanced. It is an enhanced version of K-Fold Cross Validation. The key difference is that it uses stratification, which preserves the original distribution of each class in every fold.
For example, if your original dataset has 80% Class 0 and 20% Class 1, each fold will contain roughly the same 80:20 proportion. As a result, the accuracy measured on each fold is more representative of the full dataset and therefore more reliable.
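The following minimal sketch shows this property using scikit-learn's `StratifiedKFold` on a toy label array with an 80/20 split. The labels, the choice of 5 folds, and `random_state=42` are illustrative assumptions. The features are placeholders because stratification depends only on the labels.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels: 80 samples of Class 0 and 20 of Class 1 (80% / 20%)
y = np.array([0] * 80 + [1] * 20)
X = np.zeros((len(y), 1))  # features do not affect how folds are assigned

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_ratios = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
    # Share of Class 1 inside this test fold
    ratio = y[test_idx].mean()
    fold_ratios.append(ratio)
    print(f"Fold {fold}: test size={len(test_idx)}, Class 1 share={ratio:.2f}")
```

Every test fold contains 16 Class 0 and 4 Class 1 samples, so the 20% minority share appears in each fold. Shuffling only changes *which* samples land in a fold, not the per-class counts.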
Random splitting techniques like train_test_split() or regular K-Fold can cause problems when they produce imbalanced class proportions in the training and test sets. For example, imagine a binary classification dataset with 100 samples where:

- 80 samples belong to Class 0
- 20 samples belong to Class 1
With random sampling and an 80:20 split, it is possible in the worst case for all 80 Class 0 samples to land in the training set and all 20 Class 1 samples to land in the test set. The model would then never learn to classify Class 1, and its test accuracy would be misleading.
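A truly random split only produces this extreme case occasionally. To make it reproducible, the sketch below sorts the data by class and disables shuffling in `train_test_split`, so the last 20% of rows become the test set. The use of `shuffle=False` is an assumption chosen purely to force the worst case.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 100 samples sorted by class: 80 of Class 0 followed by 20 of Class 1
y = np.array([0] * 80 + [1] * 20)
X = np.arange(len(y)).reshape(-1, 1)

# shuffle=False takes the first 80 rows for training and the last 20 for
# testing, reproducing the worst-case split deterministically
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)
train_counts = np.bincount(y_train, minlength=2)
test_counts = np.bincount(y_test, minlength=2)
print("train [Class 0, Class 1]:", train_counts)  # [80  0]: no Class 1 to learn from
print("test  [Class 0, Class 1]:", test_counts)   # [ 0 20]: only Class 1 to predict
```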
Now, let's use stratified sampling on the same dataset:
1. Training set (80 samples): 64 Class 0 (80%) and 16 Class 1 (20%)
2. Test set (20 samples): 16 Class 0 (80%) and 4 Class 1 (20%)
This ensures that both the training and test sets reflect the class proportions of the full dataset, so performance on the test set is a better indicator of how well the model generalizes.
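For a single train/test split, scikit-learn's `train_test_split` supports stratification through its `stratify` parameter. The sketch below applies it to the same toy 80/20 labels. `random_state=0` is an arbitrary assumption, and it does not change the per-class counts.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Same toy dataset: 80 samples of Class 0, 20 of Class 1
y = np.array([0] * 80 + [1] * 20)
X = np.arange(len(y)).reshape(-1, 1)

# stratify=y keeps the 80/20 class ratio in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
train_counts = np.bincount(y_train, minlength=2)
test_counts = np.bincount(y_test, minlength=2)
print("train [Class 0, Class 1]:", train_counts)  # [64 16]
print("test  [Class 0, Class 1]:", test_counts)   # [16  4]
```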
In real-world classification tasks, the distribution of observations per class is often imbalanced. In a medical dataset, for instance, 90% of patients might be healthy (Class 0) and only 10% have a disease (Class 1). If we split this data randomly, some training or test sets may contain very few or even no samples of the minority class. This is where Stratified K-Fold Cross Validation becomes important.
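The sketch below compares, for a hypothetical 90/10 medical dataset, how many diseased patients land in each of 10 test folds. It contrasts plain `KFold` with `StratifiedKFold`. It assumes the records are stored sorted by label and that `KFold` is used without shuffling, which is the extreme case.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical medical data: 90 healthy (Class 0) and 10 diseased (Class 1),
# stored sorted by label
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))

def minority_per_fold(splitter):
    """Count Class 1 samples in each test fold produced by `splitter`."""
    return [int(y[test].sum()) for _, test in splitter.split(X, y)]

plain = minority_per_fold(KFold(n_splits=10))
stratified = minority_per_fold(StratifiedKFold(n_splits=10))
print("KFold:          ", plain)       # nine folds with no diseased patients
print("StratifiedKFold:", stratified)  # exactly one diseased patient per fold
```

With plain `KFold`, nine of the ten folds would evaluate the model without a single positive case. `StratifiedKFold` spreads the 10 minority samples evenly across the folds.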
We will use Python's statistics module and scikit-learn.
We will use the breast cancer dataset that ships with scikit-learn.
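A minimal sketch of loading the dataset with `load_breast_cancer` and checking its class balance is shown below. The dataset has 569 samples and 30 features. Its classes are moderately imbalanced (212 malignant vs. 357 benign), which is why stratified folds are useful here.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
x = cancer.data    # feature/input values, shape (569, 30)
y = cancer.target  # class labels: 0 = malignant, 1 = benign

print(x.shape)
# Number of samples per class name
print(dict(zip(cancer.target_names, np.bincount(y))))
```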
- `x = cancer.data`: feature/input values
- `y = cancer.target`: output/class labels (0 or 1)
- `MinMaxScaler()`: scales features to a range between 0 and 1
- `fit_transform(x)`: fits the scaler on the data and applies the transformation

Here we will be using a logistic regression model.
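The scaling and model-creation steps listed above might look like the sketch below. The variable names `x_scaled` and `lr` are illustrative. `max_iter=1000` is an assumed setting that gives the solver room to converge and is not taken from the original.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler

cancer = load_breast_cancer()
x = cancer.data
y = cancer.target

# Rescale every feature column to [0, 1]; fit_transform learns each
# column's min/max and applies the transformation in one call
scaler = MinMaxScaler()
x_scaled = scaler.fit_transform(x)

# Logistic regression classifier (max_iter=1000 is an assumption)
lr = LogisticRegression(max_iter=1000)
```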
- `StratifiedKFold(...)`: sets up 10-fold stratified cross-validation
- `lst_accu_stratified`: empty list to store the accuracy score of each fold
- `skf.split(x, y)`: splits the dataset into stratified train/test indices
- `x_train_fold`, `x_test_fold`: features for training and testing
- `y_train_fold`, `y_test_fold`: labels for training and testing
- `max()`: highest accuracy
- `min()`: lowest accuracy
- `mean()`: average accuracy
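The cross-validation loop described above can be sketched as follows, reusing the names from the list. The `shuffle=True, random_state=1` and `max_iter=1000` settings are assumptions, so the exact fold scores will depend on them. `statistics.stdev` reports the fold-to-fold spread.

```python
import statistics

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler

cancer = load_breast_cancer()
x = MinMaxScaler().fit_transform(cancer.data)  # scale features to [0, 1]
y = cancer.target

lr = LogisticRegression(max_iter=1000)

# 10 stratified folds; shuffle/random_state are assumed settings
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
lst_accu_stratified = []

for train_index, test_index in skf.split(x, y):
    x_train_fold, x_test_fold = x[train_index], x[test_index]
    y_train_fold, y_test_fold = y[train_index], y[test_index]
    lr.fit(x_train_fold, y_train_fold)
    # score() returns the mean accuracy on the held-out fold
    lst_accu_stratified.append(lr.score(x_test_fold, y_test_fold))

print("Maximum accuracy:", max(lst_accu_stratified))
print("Minimum accuracy:", min(lst_accu_stratified))
print("Mean accuracy:", statistics.mean(lst_accu_stratified))
print("Standard deviation:", statistics.stdev(lst_accu_stratified))
```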
The model reaches an overall (mean) accuracy of 96.6% with a standard deviation of 0.02 across the folds. This indicates that it performs well and consistently from fold to fold.
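The same evaluation can also be written more compactly with scikit-learn's `cross_val_score`. The sketch below is an alternative, not the original code. It wraps `MinMaxScaler` and `LogisticRegression` in a pipeline, so the scaler is fit only on each training fold, whereas the walkthrough scales the full dataset once before splitting. The fold settings are assumptions, so the numbers may differ slightly from those reported above. Note that `scores.std()` is NumPy's population standard deviation, which can differ slightly from `statistics.stdev`.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

x, y = load_breast_cancer(return_X_y=True)

# The scaler is refit on each training fold, so test-fold statistics
# never influence the scaling step
model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1000))
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)

scores = cross_val_score(model, x, y, cv=skf, scoring="accuracy")
print(f"Mean accuracy: {scores.mean():.3f} (std {scores.std():.3f})")
```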
By using Stratified K-Fold Cross Validation, we can ensure that a machine learning model is evaluated fairly and consistently. This gives a more reliable estimate of how it will perform on real-world data.