URL: https://www.geeksforgeeks.org/machine-learning/k-means-clustering-on-the-handwritten-digits-data-using-scikit-learn-in-python/

K-Means clustering on the handwritten digits data using Scikit Learn in Python

Last Updated : 9 Apr, 2026

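K-means clustering is an unsupervised learning algorithm that partitions a dataset into k clusters by assigning each point to its nearest cluster centroid; customer segmentation is one common application. The goal is to group together data points that lie close to each other in a high-dimensional feature space. Here we apply it to the handwritten digits dataset that ships with scikit-learn.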

Load the Datasets
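
A minimal sketch of the loading step, assuming the data is scikit-learn's bundled digits set obtained with `load_digits` (the variable name `digits` is an illustrative choice):

```python
from sklearn.datasets import load_digits

# Load the bundled dataset of 8x8 handwritten digit images.
digits = load_digits()

# digits.data holds one flattened 64-pixel image per row.
# repr() shows the array(...) form that a notebook cell would display.
print(repr(digits.data))
```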

Output:

array([[ 0., 0., 5., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 10., 0., 0.],
       [ 0., 0., 0., ..., 16., 9., 0.],
       ...,
       [ 0., 0., 1., ..., 6., 0., 0.],
       [ 0., 0., 2., ..., 12., 0., 0.],
       [ 0., 0., 10., ..., 12., 1., 0.]])

Each handwritten digit in the data is an array of 64 grayscale intensity values, one for each pixel of its 8×8 image. To get a feel for the data, let's print the values of the first digit and then display its image.
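
One way this step could look, as a sketch: it assumes matplotlib is available for the plot, and the names `first_digit` and `first_image` are illustrative.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits

digits = load_digits()

# Flattened pixel intensities of the first sample (64 values).
first_digit = digits.data[0]
print("First handwritten digit data: ", first_digit)

# digits.images stores the same sample as an 8x8 array, ready for plotting.
first_image = digits.images[0]
plt.imshow(first_image, cmap="gray")
plt.axis("off")
plt.show()
```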

Output:

First handwritten digit data:  [ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15.  5.  0.  0.  3.
 15.  2.  0. 11.  8.  0.  0.  4. 12.  0.  0.  8.  8.  0.  0.  5.  8.  0.
  0.  9.  8.  0.  0.  4. 11.  0.  1. 12.  7.  0.  0.  2. 14.  5. 10. 12.
  0.  0.  0.  0.  6. 13. 10.  0.  0.  0.]

Sample image from the dataset

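We scale the data before clustering because k-means relies on Euclidean distances, so features with larger ranges would otherwise dominate the result. In this dataset the pixel values range from 0 to 16, and the scaled output below has been standardized so that each feature has zero mean and unit variance. Since this is unsupervised learning, a train-test split is not mandatory. We set k = 10 based on prior knowledge that the dataset contains the digits 0–9; the algorithm itself is not told which cluster corresponds to which digit.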
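
A sketch of the scaling step, assuming standardization with `sklearn.preprocessing.scale`, which is consistent with the output shown next. The names `data`, `labels` and `n_digits` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.preprocessing import scale

digits = load_digits()

# Standardize each pixel feature to zero mean and unit variance.
# Features that are constant (always 0) simply stay at 0.
data = scale(digits.data)

# True labels: not used for fitting, only for evaluating the clusters later.
labels = digits.target

# k for k-means: the number of distinct digit classes.
n_digits = len(np.unique(labels))

print(data)
print(labels)
```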

Output:

[[ 0.         -0.33501649 -0.04308102 ... -1.14664746 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -1.09493684 ...  0.54856067 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -1.09493684 ...  1.56568555  1.6951369
  -0.19600752]
 ...
 [ 0.         -0.33501649 -0.88456568 ... -0.12952258 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -0.67419451 ...  0.8876023  -0.5056698
  -0.19600752]
 [ 0.         -0.33501649  1.00877481 ...  0.8876023  -0.26113572
  -0.19600752]]

[0 1 2 ... 8 9 8]

Defining k-means clustering

Now we define the k-means model using the KMeans class from the sklearn.cluster module.

Method 1: Using random initial centroids

  • Setting the 'init' argument to 'random' picks the initial cluster centroids from randomly chosen data points.
  • The 'n_init' argument is the number of times k-means is run with different randomly chosen initial centroids. The run with the lowest total within-cluster variance (inertia) is kept as the final result.
  • The random state is set to 0 (any integer works) so that the same initial centroids are chosen every time the code is run.
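
The arguments above could be combined as follows. This is a sketch: `n_init=4` is an assumed value, and `kmeans_random` is an illustrative name.

```python
from sklearn.cluster import KMeans

n_digits = 10  # one cluster per digit class

# init="random": initial centroids are picked from random data points.
# n_init=4: run 4 times with different starts and keep the best run (assumed value).
# random_state=0: make the random choices reproducible.
kmeans_random = KMeans(init="random", n_clusters=n_digits, n_init=4, random_state=0)
print(kmeans_random)
```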

Method 2: Using k-means++

This is similar to Method 1, but the initial centroids are not chosen completely at random: k-means++ picks them so that they tend to be far away from each other. As a result, it should need fewer iterations to find the clusters than random initialization.
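
A sketch of the k-means++ variant, fitted on the raw digits data so that the number of iterations of the best run can be inspected through `n_iter_`. The value `n_init=4` and the name `kmeans_pp` are assumptions made for illustration.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

digits = load_digits()

# init="k-means++": spread the initial centroids out instead of
# sampling them uniformly at random.
kmeans_pp = KMeans(init="k-means++", n_clusters=10, n_init=4, random_state=0)
kmeans_pp.fit(digits.data)

# Number of iterations used by the best of the n_init runs.
print("Iterations:", kmeans_pp.n_iter_)
```

Fitting a model with `init="random"` in the same way and comparing the two `n_iter_` values gives a quick check of the claim for a particular seed.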

Model Evaluation

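We compare the two initialization methods using the time taken to fit the model, the silhouette score, and several metrics that compare the clusters with the true digit labels: homogeneity, completeness, V-measure, adjusted Rand index and adjusted mutual information.

A small helper function fits a k-means model and prints these scores. We call it once for each initialization method.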
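
A possible shape for such a helper and its two calls is sketched below. The function name `bench_k_means`, its signature, the silhouette `sample_size` of 300, and the choice to fit on the unscaled `digits.data` are all assumptions, so the printed numbers may differ from the output that follows.

```python
from time import time

from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits


def bench_k_means(estimator, name, data, labels):
    """Fit `estimator` on `data`, print the scores, and return them as a dict."""
    start = time()
    estimator.fit(data)
    elapsed = time() - start
    pred = estimator.labels_
    scores = {
        "Time taken": elapsed,
        "Homogeneity": metrics.homogeneity_score(labels, pred),
        "Completeness": metrics.completeness_score(labels, pred),
        "V_measure": metrics.v_measure_score(labels, pred),
        "Adjusted random": metrics.adjusted_rand_score(labels, pred),
        "Adjusted mutual info": metrics.adjusted_mutual_info_score(labels, pred),
        # Silhouette needs no true labels; it is computed on a sample to keep it fast.
        "Silhouette": metrics.silhouette_score(
            data, pred, metric="euclidean", sample_size=300, random_state=0
        ),
    }
    print("Initial-cluster: " + name)
    for key, value in scores.items():
        print(f"{key}: {value:.3f}")
    print()
    return scores


digits = load_digits()
results = {}
for init in ("random", "k-means++"):
    model = KMeans(init=init, n_clusters=10, n_init=4, random_state=0)
    results[init] = bench_k_means(model, init, digits.data, digits.target)
```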

Output:

Initial-cluster: random
Time taken: 0.302
Homogeneity: 0.739
Completeness: 0.748
V_measure: 0.744
Adjusted random: 0.666
Adjusted mutual info: 0.741
Silhouette: 0.191

Initial-cluster: k-means++
Time taken: 0.386
Homogeneity: 0.742
Completeness: 0.751
V_measure: 0.747
Adjusted random: 0.669
Adjusted mutual info: 0.744
Silhouette: 0.175

Visualizing the K-means clustering for handwritten data

  • Reducing the 64-dimensional data to two dimensions with Principal Component Analysis (PCA) so that it can be plotted.
  • Fitting the k-means++ model defined earlier on the reduced data.
  • Plotting the data points with the scatter function from the matplotlib module, using a different color for each cluster and marking the centroid of each cluster.
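
The steps above can be sketched as follows, assuming two principal components and a fresh k-means++ model with the same assumed settings as before. The colormap, marker style and variable names are illustrative choices.

```python
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

digits = load_digits()

# Project the 64-dimensional data onto its first two principal components.
reduced = PCA(n_components=2).fit_transform(digits.data)

# Cluster the 2-D points with k-means++ initialization.
kmeans = KMeans(init="k-means++", n_clusters=10, n_init=4, random_state=0)
cluster_ids = kmeans.fit_predict(reduced)
centroids = kmeans.cluster_centers_

# One color per cluster, with each centroid marked by a black cross.
plt.scatter(reduced[:, 0], reduced[:, 1], c=cluster_ids, cmap="tab10", s=8)
plt.scatter(centroids[:, 0], centroids[:, 1], marker="x", s=150,
            linewidths=3, color="black")
plt.title("K-means clusters on PCA-reduced digits data")
plt.show()
```

Because the clustering here is done in the reduced 2-D space, the cluster assignments can differ from those obtained on the full 64-dimensional data.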

Output:

Clusters of the data points

