KNN From Scratch
I have been asked to implement a KNN classifier in several interviews, so I thought I’d write down my simplest implementation on MNIST and talk about how we can speed it up.
First off, import our Python libraries.
import mnist
import numpy as np
from scipy.stats import mode
from sklearn.metrics import accuracy_score
Now load the MNIST dataset and do some pre-processing (normalize and flatten).
X, y = mnist.train_images(), mnist.train_labels()
Xt, yt = mnist.test_images(), mnist.test_labels()
# normalize
mu = X.mean()
sigma = X.std()
X = (X - mu) / sigma
Xt = (Xt - mu) / sigma
# flatten
X = X.reshape(-1, 28*28)
Xt = Xt.reshape(-1, 28*28)
Implementing KNN for a single test datapoint is very simple.
- Calculate the distance from the test datapoint to all the training datapoints
- Sort the distance values and take the $k$ training points with the smallest distances
- Take the mode of the corresponding labels as the output class label.
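In our MNIST example, we use Euclidean distance as the distance measure: the L2 norm of the difference between two flattened images, which equals the Frobenius norm of the difference between the original 28×28 images. The code works with the squared distance and skips the square root, which does not change the ordering of the neighbors.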
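To illustrate why the square root can be skipped, here is a small check on made-up data (the arrays `train` and `query` are illustrative and not part of the MNIST pipeline). Because the square root is monotonic, sorting by squared distance and sorting by true Euclidean distance should give the same neighbor order.

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(6, 4))   # 6 toy training points, 4 features (made-up data)
query = rng.normal(size=4)        # one toy test point

sq_dist = np.sum((train - query) ** 2, axis=1)   # squared Euclidean distance
dist = np.sqrt(sq_dist)                          # true Euclidean distance

# Sort both ways; the neighbor ordering should be identical.
order_sq = np.argsort(sq_dist, kind="stable")
order_l2 = np.argsort(dist, kind="stable")
print(np.array_equal(order_sq, order_l2))
```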
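def knn_single(X, y, Xt, k=5):
    # squared Euclidean distance from the test point to every training point
    dist = np.sum((X - Xt)**2, axis=1)
    # indices of the k nearest training points
    idx = np.argsort(dist, axis=0)[:k]
    votes = y[idx]
    # majority vote; atleast_1d copes with older SciPy (length-1 array)
    # and newer SciPy (scalar) return values for 1-D input
    y_pred = np.atleast_1d(mode(votes, axis=0)[0])[0]
    return y_pred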
knn_single(X, y, Xt[0], k=5)
Speeding Up Multiple Test Cases
Now this works well for a single test sample, but imagine being given $10K$ test samples at once and asked to classify all of them. Running knn_single() in a loop will be very slow, since every call makes a separate pass over the whole training set. Is there a way to speed it up?
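For reference, the looped baseline might look like the sketch below. The names `knn_single_sketch` and `knn_loop` are hypothetical, the toy data is made up, and `np.bincount` is swapped in for `scipy.stats.mode` to keep the example free of extra dependencies; it assumes non-negative integer labels.

```python
import numpy as np

def knn_single_sketch(X, y, x, k=5):
    # Squared Euclidean distance from x to every training row.
    dist = np.sum((X - x) ** 2, axis=1)
    idx = np.argsort(dist)[:k]
    # Majority vote; bincount assumes non-negative integer labels.
    return np.bincount(y[idx]).argmax()

def knn_loop(X, y, Xt, k=5):
    # One call per test sample: n separate passes over the training set.
    return np.array([knn_single_sketch(X, y, x, k) for x in Xt])

# Toy data: two well-separated clusters (made up for illustration).
rng = np.random.default_rng(0)
X_toy = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
y_toy = np.array([0] * 20 + [1] * 20)
Xt_toy = np.array([[0.0, 0.0], [5.0, 5.0]])

pred_toy = knn_loop(X_toy, y_toy, Xt_toy, k=3)
print(pred_toy)
```

The Python-level loop runs once per test sample, so the cost grows linearly with the number of test samples and none of the work is shared between calls.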
Let’s see if we can vectorize our knn over multiple test samples. Given m source samples X and n target samples Xt, we want a vectorized way to calculate all pairwise distances between the source and the target samples. The resulting distances can be stored in a 2D matrix D[n, m]. Once we have that, we can take the k closest samples from the training set for each test sample and take the mode of their labels as our classification result.
Turns out there is a simple way to do it. We all know the formula $(a-b)^2 = a^2 - 2ab + b^2$, right? Applied to vectors, it reads $\|a-b\|^2 = \|a\|^2 - 2a^\top b + \|b\|^2$, which lets us calculate the pairwise squared distances between X and Xt with a single matrix product. The code is shown below.
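def knn(X, y, Xt, k=5):
    # pairwise squared distances, shape (n_test, n_train):
    # ||xt||^2 - 2 xt.x + ||x||^2
    dists = -2 * np.dot(Xt, X.T) + np.sum(X**2, axis=1).reshape(1, -1) \
        + np.sum(Xt**2, axis=1).reshape(-1, 1)
    # indices of the k nearest training points for each test point
    idx = np.argsort(dists, axis=1)[:, :k]
    votes = y[idx]
    # majority vote per row
    yt_pred = mode(votes, axis=1)[0]
    return yt_pred.reshape(-1)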
knn(X, y, Xt, k=5)
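As a sanity check on the expansion used in knn, the sketch below compares it against a brute-force broadcasted computation on small random arrays. The arrays `A` and `B` are made-up stand-ins for X and Xt, not MNIST data.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))   # stands in for X  (m = 5 training rows)
B = rng.normal(size=(4, 3))   # stands in for Xt (n = 4 test rows)

# Brute force: broadcast to shape (n, m, d), then sum over features.
brute = np.sum((B[:, None, :] - A[None, :, :]) ** 2, axis=2)

# Expansion: ||b||^2 - 2 b.a + ||a||^2, shape (n, m).
expanded = (
    np.sum(B**2, axis=1).reshape(-1, 1)
    - 2 * B @ A.T
    + np.sum(A**2, axis=1).reshape(1, -1)
)

print(np.allclose(brute, expanded))
```

The two results agree only up to floating-point rounding, and the expanded form can come out very slightly negative for near-identical points. That is harmless when the values are only used for ranking neighbors.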