K-Nearest Neighbors (KNN) with sklearn in Python

The popular K-Nearest Neighbors (KNN) algorithm is used for regression and classification in many applications, such as recommender systems, image classification, and financial forecasting. It is the basis of many advanced machine learning techniques (e.g., in information retrieval). Understanding KNN is an important building block of a solid computer science education.

K-Nearest Neighbors (KNN) is a robust, simple, and popular machine learning algorithm. It’s relatively easy to implement from scratch while being competitive and performant.

Recap Machine Learning

Machine learning is all about learning a so-called model from a given training data set.

This model can then be used for inference, i.e., predicting output values for potentially new and unseen input data.

A model usually is a high-level abstraction such as a mathematical function inferred from the training data. Most machine learning techniques attempt to find patterns in the data that can be captured and used for generalization and prediction on new input data.

KNN Training

However, KNN follows a quite different path. The simple idea is the following: the whole data set is your model.

Yes, you read that right.

The KNN machine learning model is nothing more than a set of observations. Every single instance of your training data is a part of your model. Training becomes as simple as throwing the training data into a container data structure for later retrieval. There’s no complicated training phase and no hours of distributed GPU processing to extract patterns from the data.

KNN Inference

A great advantage is that you can use the KNN algorithm for regression or classification, as you like. Given your input vector x, you execute the following strategy:

  • Find the K nearest neighbors of x according to a predefined similarity metric.
  • Aggregate the K nearest neighbors into a single “prediction” or “classification” value. You can use any aggregator function such as mean, median, max, or min.

That’s it. Simple, isn’t it?
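To make the two steps concrete, here is a minimal from-scratch sketch in plain Python. It is not how sklearn implements KNN. The function name knn_predict, the toy data, and the restriction to one-dimensional inputs are assumptions made for illustration.

```python
from statistics import mean

def knn_predict(train, x, k=3, aggregate=mean):
    """Predict a value for x from the k nearest (feature, target) pairs.

    train: list of (feature, target) tuples, i.e. the stored "model"
    x: a single numeric query value (one-dimensional for simplicity)
    """
    # Step 1: rank the stored observations by their distance to the query.
    by_distance = sorted(train, key=lambda pair: abs(pair[0] - x))
    # Step 2: aggregate the targets of the k closest observations.
    return aggregate(target for _, target in by_distance[:k])

# Hypothetical toy data: (input, output) pairs.
train = [(1, 10), (2, 20), (3, 30), (10, 100)]

print(knn_predict(train, 2.2))                 # mean of the targets 20, 30, 10
print(knn_predict(train, 2.2, aggregate=max))  # max of the same three targets
```

Swapping the aggregate argument is all it takes to change the aggregation method. Changing the key function changes the distance metric.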

Consider the following example.

Suppose your company sells homes for clients. It has acquired a large database of customers and observed house prices.

One day, your client asks how much he can expect to pay for a 52-square-meter house. You query your KNN “model” and it immediately gives you the response $33,167. And indeed, your client finds a home for $33,489 the same week. How did the KNN system come to this surprisingly accurate prediction?

It simply calculated the K=3 nearest neighbors to the query “D=52 square meters” from the model with respect to Euclidean distance. The three nearest neighbors are A, B, and C with prices $34,000, $33,500, and $32,000, respectively. In the final step, KNN aggregates the three nearest neighbors by calculating their simple average: ($34,000 + $33,500 + $32,000) / 3 ≈ $33,167. As K=3 in this example, we denote the model as “3NN”.
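The aggregation step of this example can be checked in a few lines. Only the three neighbor prices come from the example. The variable names are made up for illustration.

```python
from statistics import mean

# Prices of the three nearest neighbors A, B, and C from the example.
neighbor_prices = [34_000, 33_500, 32_000]

# The 3NN prediction is the simple average of the neighbor prices.
prediction = mean(neighbor_prices)
print(round(prediction))  # 33167
```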

Of course, you can vary the similarity functions, the parameter K, and the aggregation method to come up with more sophisticated prediction models.

Another advantage of KNN is that it can be easily adapted as new observations are made: you simply add them to the data set. This is not true of machine learning models in general. The obvious weakness is that finding the nearest neighbors becomes more and more expensive as you add points. To compensate, you can continuously remove “stale” values from the system.
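One possible way to drop stale values is a fixed-size window that forgets the oldest observation whenever a new one arrives. The sketch below uses collections.deque for that. The window size of 4, the predict helper, and the sample observations are assumptions for illustration, not part of sklearn.

```python
from collections import deque
from statistics import mean

# Keep only the most recent observations; the oldest one is evicted
# automatically once the window is full. The capacity is hypothetical.
window = deque(maxlen=4)

def predict(x, k=3):
    """Average the prices of the k observations closest to size x."""
    nearest = sorted(window, key=lambda pair: abs(pair[0] - x))[:k]
    return mean(price for _, price in nearest)

# Illustrative (size, price) observations.
for obs in [(35, 30000), (45, 45000), (40, 50000), (35, 35000)]:
    window.append(obs)
before = predict(30)

window.append((25, 32500))  # new observation; (35, 30000) is evicted
after = predict(30)

print(round(before), round(after))
```

The "model" is updated by a single append, with no retraining step. The price is that every prediction still scans the stored observations.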

As I mentioned above, you can also use KNN for classification problems. Instead of averaging over the K nearest neighbors, you can simply use a voting mechanism where each nearest neighbor votes for its class. The class with the most votes wins.
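The voting mechanism can be sketched with collections.Counter. The function knn_classify and the "low"/"high" labels are hypothetical and only serve to illustrate the idea.

```python
from collections import Counter

def knn_classify(train, x, k=3):
    """Return the most common label among the k nearest (feature, label) pairs."""
    nearest = sorted(train, key=lambda pair: abs(pair[0] - x))[:k]
    # Each of the k nearest neighbors casts one vote for its own class.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical data: house size in square meters with a made-up price class.
train = [(25, "low"), (35, "low"), (35, "low"),
         (40, "high"), (40, "high"), (45, "high")]

print(knn_classify(train, 30))  # low
print(knn_classify(train, 41))  # high
```

With two classes, an odd K avoids tied votes. In this sketch, a tie would be resolved in favor of the label that was counted first.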

Implementing KNN with SKLearn

## Dependencies
from sklearn.neighbors import KNeighborsRegressor
import numpy as np


## Data (House Size (square meters) / House Price ($))
X = np.array([[35, 30000], [45, 45000], [40, 50000],
              [35, 35000], [25, 32500], [40, 40000]])


## One-liner
KNN = KNeighborsRegressor(n_neighbors=3).fit(X[:,0].reshape(-1,1), X[:,1].reshape(-1,1))


## Result & puzzle
res = KNN.predict([[30]])
print(res)

The snippet above shows how to use KNN in Python in a single line of code.

Take a guess: what’s the output of this code snippet?

Understanding the Code

To help you see the result, let’s look at the housing data from the code.

Can you see the general trend? Larger houses tend to sell for higher prices. In this small sample, the relationship is only roughly linear: the 25-square-meter house costs $32,500, while the 45-square-meter house costs $45,000.

In the code, the client requests your price prediction for a house with 30 square meters. What does KNN with K=3 (in short: 3NN) predict?

Beautifully simple, isn’t it? The KNN algorithm finds the three houses closest in size to the query and predicts the average of their prices. For a query of 30 square meters, these are the two 35-square-meter houses ($30,000 and $35,000) and the 25-square-meter house ($32,500).

Thus, the result is $32,500.
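If you want to verify this result by hand, the following plain-Python sketch redoes the calculation without sklearn. The data is taken from the code above. The variable names are chosen for illustration.

```python
# House sizes (square meters) and prices ($) from the article's data.
sizes = [35, 45, 40, 35, 25, 40]
prices = [30000, 45000, 50000, 35000, 32500, 40000]
query = 30

# Rank all houses by their distance to the query size and keep the closest 3.
ranked = sorted(zip(sizes, prices), key=lambda house: abs(house[0] - query))
nearest = ranked[:3]

# The 3NN prediction is the average price of these three houses.
prediction = sum(price for _, price in nearest) / len(nearest)

print(nearest)     # [(35, 30000), (35, 35000), (25, 32500)]
print(prediction)  # 32500.0
```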

Maybe you were confused by the data conversion part within the one-liner. Let me quickly explain what happened here:

## One-liner
KNN = KNeighborsRegressor(n_neighbors=3).fit(X[:,0].reshape(-1,1), X[:,1].reshape(-1,1))

First, we create a new machine learning model of the class “KNeighborsRegressor”. If you would like to use KNN for classification, you would use the class “KNeighborsClassifier” instead.

Second, we “train” the model using the fit function with two arguments. The first argument defines the input (the house size) and the second argument defines the output (the house price). The input must be two-dimensional, so that each observation is itself an array-like data structure. For example, you wouldn’t use “30” as an input but “[30]”. The reason is that, in general, an observation can have several features rather than just one. The output may also be passed as a plain 1D array; reshaping it into a column, as done here, makes predict() return a 2D array. Therefore, we reshape the input:

print(X[:,0])
"[35 45 40 35 25 40]"

If we used this 1D NumPy array as the input to the fit() function, the function would raise a ValueError because it expects an array of (array-like) observations, not an array of integers.

Therefore, we convert the array accordingly using the reshape() function:

print(X[:,0].reshape(-1,1))
"""
[[35]
 [45]
 [40]
 [35]
 [25]
 [40]]
"""

Now, we have six array-like observations. The value -1 in the reshape() function call is not an index but our “laziness” expression: we want NumPy to determine the number of rows automatically, and we only specify how many columns we need (i.e., 1 column).
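The effect of the reshape can be checked directly. The sketch below reuses the data from the article and assumes NumPy and scikit-learn are installed. The variable names sizes_1d, sizes_2d, and rejected are made up for illustration.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

X = np.array([[35, 30000], [45, 45000], [40, 50000],
              [35, 35000], [25, 32500], [40, 40000]])

sizes_1d = X[:, 0]                 # one flat array of six integers
sizes_2d = X[:, 0].reshape(-1, 1)  # six observations with one feature each
print(sizes_1d.shape)  # (6,)
print(sizes_2d.shape)  # (6, 1)

# fit() rejects the 1D input because it expects one row per observation.
try:
    KNeighborsRegressor(n_neighbors=3).fit(sizes_1d, X[:, 1])
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True

# The target can stay one-dimensional; predict() then returns a 1D array.
model = KNeighborsRegressor(n_neighbors=3).fit(sizes_2d, X[:, 1])
print(model.predict([[30]]))
```

Leaving the target one-dimensional is a design choice: the prediction comes back as a flat array instead of the nested array that the reshaped target in the one-liner produces.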

This article is based on a chapter of my book Python One-Liners.

