In this post, we present a working example of the k-nearest neighbor (KNN) classifier. We covered the theory behind this algorithm previously; please refer to the Nearest Neighbor Classifier – From Theory to Practice post for further details.

A Recap of the Nearest Neighbor Classifier

When we use KNN for classification, the prediction is the class that occurs most frequently among the K nearest neighbors of the test sample. Each neighbor votes for its own class, and the class with the most votes is selected. Accordingly, the probability that the test sample belongs to class j can be estimated as the number of neighbors belonging to class j divided by the total number of neighbors (which is K):

    \[P(C=j) = \frac{Num(C=j)}{\sum_{i}Num(C=i)}\]
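
To make the vote and the probability estimate concrete, here is a minimal from-scratch sketch in Python. It assumes Euclidean distance and unweighted votes; the function names knn_class_probabilities and knn_predict, as well as the toy data, are hypothetical and chosen only for illustration. They are not part of Scikit-learn.

```python
from collections import Counter

import numpy as np


def knn_class_probabilities(X_train, y_train, x, k):
    """Return {class: P(C=class)} estimated from the k nearest training samples to x."""
    # Euclidean distance from the query point x to every training sample.
    distances = np.linalg.norm(X_train - x, axis=1)
    # Indices of the k closest training samples.
    nearest = np.argsort(distances)[:k]
    # Num(C=j): how many of the k neighbors belong to each class j.
    counts = Counter(y_train[nearest].tolist())
    total = sum(counts.values())  # the denominator, equal to k
    return {label: n / total for label, n in counts.items()}


def knn_predict(X_train, y_train, x, k):
    probs = knn_class_probabilities(X_train, y_train, x, k)
    # Majority vote: the class with the largest fraction of the neighbors.
    return max(probs, key=probs.get)


# Hypothetical toy 2-D training set with two classes.
X_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [1.0, 1.0], [0.9, 1.1]])
y_train = np.array([0, 0, 0, 1, 1])
x = np.array([0.3, 0.3])

print(knn_class_probabilities(X_train, y_train, x, k=4))  # {0: 0.75, 1: 0.25}
print(knn_predict(X_train, y_train, x, k=4))  # 0
```

Three of the four nearest neighbors belong to class 0, so P(C=0) = 3/4 and the vote selects class 0.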

It is often said that k should be odd to avoid a tied vote. In general, this is wrong. Assume the seven nearest neighbors of a test sample come from three classes, red, green, and blue, with 3, 3, and 1 samples, respectively. With k=7, which is odd, the classifier still cannot decide between red and green! An odd k only guarantees a strict majority when there are exactly two classes. With more than two classes, there is no mathematical justification for preferring odd or even values of k.
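
The tie argument can be checked by brute force. The sketch below counts the votes for the example above and enumerates every way k votes can be split among a given number of classes. The helpers vote and tie_possible are hypothetical, written only for this illustration.

```python
from collections import Counter
from itertools import product


def vote(neighbor_labels):
    """Return every class that received the highest number of votes."""
    counts = Counter(neighbor_labels)
    top = max(counts.values())
    # More than one class in the result means the vote is tied.
    return sorted(label for label, n in counts.items() if n == top)


def tie_possible(k, n_classes):
    """True if some split of k votes among n_classes classes ties for first place."""
    for split in product(range(k + 1), repeat=n_classes):
        if sum(split) == k and split.count(max(split)) > 1:
            return True
    return False


# k = 7 (odd) with three classes: 3 red, 3 green, 1 blue.
print(vote(['red'] * 3 + ['green'] * 3 + ['blue']))  # ['green', 'red']: a tie
# Two classes, k = 1..7: odd k never ties, even k can.
print([tie_possible(k, 2) for k in range(1, 8)])  # [False, True, False, True, False, True, False]
# Three classes: the odd value k = 7 can still tie.
print(tie_possible(7, 3))  # True
```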

How to code?

In this section, we show how to implement KNN using Python and Scikit-learn.

Reading the data

First, we read the data we want to use as input. We use the Iris data set from the UCI Machine Learning Repository and load it with Pandas.

import matplotlib.pyplot as plt
import pandas as pd
# Location of the Iris data set in the UCI Machine Learning Repository.
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
# The file has no header row, so we determine the column names ourselves.
col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'categories']
iris = pd.read_csv(url, header=None, names=col_names)
### PLOT ###
# Scatter plot of sepal length vs. sepal width, one marker color per class.
fig, ax = plt.subplots()
for name in ['Iris-setosa', 'Iris-virginica', 'Iris-versicolor']:
    subset = iris[iris['categories'] == name]
    ax.plot(subset['sepal_length'], subset['sepal_width'], marker='o', linestyle='', ms=12, label=name)
ax.set_xlabel('sepal_length')
ax.set_ylabel('sepal_width')
ax.legend()
plt.show()

Plotting ‘sepal_length’ against ‘sepal_width’, colored by class, produces the following scatter plot:

The scatter plot of the data distribution for three categories.

Preprocessing

In the preprocessing phase, we perform the following:

  • Create the data and label vectors.
  • Split the data into training and test sets.

The following code performs both operations:

# Map each class name to an integer and store it in a new 'labels' column.
iris_class = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
iris['labels'] = iris['categories'].map(iris_class)
# Data matrix X: the four measurements (drop the class-name and label columns).
X = iris.drop(['categories', 'labels'], axis=1)
# Label vector Y.
Y = iris['labels']
# Split the data into train/test sets (by default, 25% of the samples go to the test set).
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=1)
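
As a quick check of what the split produces, the sketch below repeats it on Scikit-learn's bundled copy of the Iris data (load_iris), used here only so the snippet runs offline; we assume it is close enough to the UCI file for this illustration. Its integer targets follow the same 0/1/2 ordering as iris_class.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Bundled Iris data: targets are 0 = setosa, 1 = versicolor, 2 = virginica.
X, Y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=1)

# Default test_size=0.25: ceil(0.25 * 150) = 38 test samples, 112 training samples.
print(X_train.shape, X_test.shape)  # (112, 4) (38, 4)
# Samples per class in each split; the split is not stratified, so these need not be equal.
print(np.bincount(y_train), np.bincount(y_test))
```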

Running the classifier

In this step, we create the classifier, fit it on the training data, and use it to classify the test data.

from sklearn.neighbors import KNeighborsClassifier
## Create the model with k=10 neighbors.
knn = KNeighborsClassifier(n_neighbors=10)
## Fit the model using the training data.
knn.fit(X_train, y_train)
## Test phase: score() returns the mean accuracy on the test data.
print(knn.score(X_test, y_test))

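The printed score, which is the mean accuracy on the test set, should be around 0.973 (97.3%). Since the default split holds out 38 of the 150 samples, this corresponds to 37 correctly classified test samples.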
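
The fitted classifier also exposes the class probabilities from the recap. The sketch below recomputes P(C=j) = Num(C=j)/k by hand from the neighbors returned by kneighbors and compares the result with predict_proba. It uses hypothetical synthetic data rather than Iris, so that no two neighbors are equidistant and the comparison is unambiguous. The agreement assumes the default weights='uniform'; with weights='distance', each vote is weighted by the inverse of its distance and the simple fractions no longer apply.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical synthetic data: 60 random 2-D points in 3 classes, 5 query points.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(60, 2))
y_train = np.repeat([0, 1, 2], 20)
X_test = rng.normal(size=(5, 2))

k = 10
knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)

# Indices (into X_train) of the k nearest neighbors of each test point.
neighbor_idx = knn.kneighbors(X_test, return_distance=False)
neighbor_labels = y_train[neighbor_idx]  # shape (5, k)

# P(C=j) = Num(C=j) / k for every test point and class j, computed by hand.
manual_proba = np.stack([(neighbor_labels == j).mean(axis=1) for j in range(3)], axis=1)

# With uniform weights, predict_proba reports the same fractions.
print(np.allclose(manual_proba, knn.predict_proba(X_test)))  # True
```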

Summary

In this post, we implemented a simple KNN classifier using Scikit-learn. In another post, we explained the theory behind this algorithm. This implementation shows the effectiveness of the KNN algorithm on a relatively small, low-dimensional dataset.
