Mean-Shift Clustering Algorithm


This tutorial is from the open-source community.


This lab guides you through applying the Mean-Shift clustering algorithm with the scikit-learn library in Python. You will learn how to generate sample data, compute clustering with MeanShift, and plot the results.



Import Required Libraries

First, we need to import the required libraries: numpy, sklearn.cluster, sklearn.datasets, and matplotlib.pyplot.

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

Generate Sample Data

Next, we will generate sample data using the make_blobs function from the sklearn.datasets module. We will create a dataset with 10,000 samples and three clusters with centers at [[1, 1], [-1, -1], [1, -1]] and a standard deviation of 0.6.

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
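Before clustering, it can help to confirm what make_blobs returned. The sketch below repeats the call with two assumptions that are not part of the lab: a fixed random_state so the draw is repeatable, and the name y_true for the generating labels that the lab discards with `_`.

```python
import numpy as np
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]

# random_state=0 is an assumed seed, added only so the same points are drawn each run.
# y_true (discarded as `_` in the lab) records which center generated each sample.
X, y_true = make_blobs(
    n_samples=10000, centers=centers, cluster_std=0.6, random_state=0
)

print(X.shape)              # one row per sample, one column per feature
print(np.bincount(y_true))  # number of samples drawn around each center
```

Mean-Shift never sees y_true; it is useful only afterwards, for checking how closely the estimated clusters match the blobs that generated the data.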

Compute Clustering with MeanShift

Now we will compute the clustering with the MeanShift class from the sklearn.cluster module. We will use the estimate_bandwidth function to estimate the bandwidth parameter, which sets the size of the region of influence around each point.

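# Estimate the kernel bandwidth from pairwise distances in a 500-sample subset
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)  # labels_ and cluster_centers_ exist only after fitting
labels = ms.labels_
cluster_centers = ms.cluster_centers_
n_clusters_ = len(np.unique(labels))  # used by the plotting step
print("Number of estimated clusters: %d" % n_clusters_)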
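To make the role of the bandwidth concrete, here is a minimal NumPy sketch of the update that Mean-Shift repeats: replace a point with the mean of all samples within one bandwidth of it (a flat kernel) until it stops moving. The helpers shift_point and seek_mode, the tolerance, and the toy data are all hypothetical names and values chosen for illustration. This is not scikit-learn's implementation, which also handles seeding and merges near-duplicate modes.

```python
import numpy as np


def shift_point(point, data, bandwidth):
    """Hypothetical helper: one flat-kernel mean-shift update for a single point."""
    distances = np.linalg.norm(data - point, axis=1)
    neighbors = data[distances <= bandwidth]  # samples inside the window
    return neighbors.mean(axis=0)


def seek_mode(start, data, bandwidth, tol=1e-3, max_iter=300):
    """Hypothetical helper: repeat the update until the point moves less than tol."""
    point = np.asarray(start, dtype=float)
    for _ in range(max_iter):
        shifted = shift_point(point, data, bandwidth)
        if np.linalg.norm(shifted - point) < tol:
            return shifted
        point = shifted
    return point


rng = np.random.default_rng(0)
# Two well-separated toy blobs (not the lab's dataset)
data = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=0.3, size=(200, 2)),
    rng.normal(loc=[5.0, 5.0], scale=0.3, size=(200, 2)),
])

# Start from one sample in each blob and climb to the nearest density peak
mode_a = seek_mode(data[0], data, bandwidth=1.0)
mode_b = seek_mode(data[-1], data, bandwidth=1.0)
print(mode_a, mode_b)
```

Each starting point climbs to the dense region of its own blob. A larger bandwidth widens the window, so nearby peaks are averaged together and fewer clusters result; a smaller one tends to split the data into more clusters.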

Plot Results

Finally, we will plot the results using the matplotlib.pyplot library. We will use different colors and markers for each cluster, and we will also plot the cluster centers.


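colors = ["#dede00", "#377eb8", "#f781bf"]
markers = ["x", "o", "^"]

for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)
    # Draw the cluster center larger, with a black edge, so it stands out
    plt.plot(
        cluster_center[0],
        cluster_center[1],
        markers[k],
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=14,
    )
plt.title("Estimated number of clusters: %d" % n_clusters_)
plt.show()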
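The loop above zips over three hard-coded colors, so it draws at most three clusters even if Mean-Shift finds more. The sketch below is one possible alternative, not part of the lab: it picks colors from matplotlib's "tab10" colormap so any number of clusters is drawn. The smaller sample size, the fixed random_state values, and the names centers_found and n_found are assumptions made to keep the example quick and repeatable.

```python
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Smaller, seeded dataset (assumed values) so the sketch runs quickly
X, _ = make_blobs(
    n_samples=1000,
    centers=[[1, 1], [-1, -1], [1, -1]],
    cluster_std=0.6,
    random_state=0,
)
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500, random_state=0)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X)

labels = ms.labels_
centers_found = ms.cluster_centers_
n_found = len(centers_found)

cmap = plt.get_cmap("tab10")  # 10 distinct colors, reused cyclically if needed
fig, ax = plt.subplots()
for k in range(n_found):
    members = labels == k
    color = cmap(k % 10)
    # Cluster members as small dots, the center as a large star with a black edge
    ax.scatter(X[members, 0], X[members, 1], s=8, color=color)
    ax.scatter(
        centers_found[k][0], centers_found[k][1],
        s=200, marker="*", color=color, edgecolors="k",
    )
ax.set_title(f"Estimated number of clusters: {n_found}")
```

In a notebook the figure is displayed when the cell finishes; in a plain script, call plt.show() afterwards.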


In this lab, we learned how to implement the Mean-Shift Clustering Algorithm using the Scikit-learn library in Python. We generated sample data, computed clustering with MeanShift, and plotted the results. By the end of this lab, you should have a better understanding of how to use the Mean-Shift algorithm to cluster data.
