This article is contributed by Antonis Maronikolakis.
Objective: We are given a data set of items, each described by numeric feature values and categorized into a class. We are also given a new item without a class, and we want to assign it to one of the given classes.
Approach: We will use the k Nearest Neighbors algorithm (kNN for short), which classifies a new item based on its closest neighbors. In other words, the algorithm finds the k items closest to the new item, checks which class is most common among them, and assigns the new item to that class.
Algorithm:
1. Read the data from the file.
2. Given a new item:
   a. Compute the distances from the new item to all other items.
   b. Pick the k closest items.
   c. Find the most frequent class among these k items.
   d. Categorize the new item into that class.
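The steps above can be condensed into a short standard-library sketch. It is not the implementation developed in this article: it assumes each data item is a (features, label) pair whose features are a tuple of numbers, and the function name knn_classify and the toy data set are hypothetical. It uses math.dist, which needs Python 3.8 or later.

```python
import heapq
import math
from collections import Counter


def knn_classify(new_point, data, k):
    """Classify new_point by majority vote among its k nearest neighbors.

    new_point: a tuple of numeric feature values.
    data: a list of (features, label) pairs; features has the same length
          as new_point. (Hypothetical layout chosen for this sketch.)
    """
    # Step a: distance from the new point to every item in the data set
    distances = ((math.dist(new_point, features), label) for features, label in data)
    # Step b: keep only the k pairs with the smallest distance
    nearest = heapq.nsmallest(k, distances)
    # Steps c and d: the most frequent label among the k nearest wins
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two well-separated groups
toy = [((1.0, 1.0), "A"), ((1.2, 0.8), "A"), ((5.0, 5.0), "B"), ((5.5, 4.5), "B")]
print(knn_classify((1.1, 1.0), toy, 3))  # A
```

heapq.nsmallest keeps only the k smallest (distance, label) pairs instead of sorting every distance, and Counter.most_common performs the majority vote. If two distances are equal, the tuples fall back to comparing labels, which is harmless for string labels.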
The algorithm will be implemented in Python.
Input: Our input is a text file. The first line holds the feature names, followed by a header for the class column. Each remaining line holds one item: its comma-separated feature values followed by the class the item belongs to. An example file can be found here. The data set is called Fisher’s Iris, a popular data set in Machine Learning exercises.
We will read the data from the file (stored in “data.txt”) and we will split it by lines.
```python
# Read the whole file and split it into a list of lines
with open('data.txt', 'r') as f:
    lines = f.read().splitlines()
```
We will store the feature names (appearing on the first line) in a list:
```python
# Split the header by commas and drop the last entry (the class column)
features = lines[0].split(',')[:-1]
```
We will store the data in a list named items. Each item in the list is a dictionary whose keys are the feature names, plus “Class”, which stores the class the item belongs to.
We will also shuffle the items to ensure they are in a random order.
```python
from random import shuffle

items = []
for i in range(1, len(lines)):
    line = lines[i].split(',')
    # The last value on the line is the item's class
    itemFeatures = {"Class": line[-1]}
    # Iterate through the features
    for j in range(len(features)):
        f = features[j]      # Get the feature name at index j
        v = float(line[j])   # Convert the value to a float
        itemFeatures[f] = v  # Add the feature to the dict
    items.append(itemFeatures)  # Append the item's dict to items

shuffle(items)
```
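If the file may contain blank lines or quoted fields, the standard csv module is a safer way to split each line. The sketch below is an alternative to the loop above under stated assumptions: the header order SL,SW,PL,PW,Class and the two sample rows are illustrative only, and parse_items is a hypothetical helper that reads from a string instead of data.txt (shuffling is left out).

```python
import csv
import io

# Illustrative sample in the format described above (assumed header order):
# feature names, then the class column; one item per line.
SAMPLE = """SL,SW,PL,PW,Class
5.1,3.5,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
"""


def parse_items(text):
    """Parse CSV text into a list of dicts: feature name -> float, plus "Class"."""
    rows = [row for row in csv.reader(io.StringIO(text)) if row]  # skip blank lines
    features = rows[0][:-1]  # every header except the last (class) column
    items = []
    for row in rows[1:]:
        item = {"Class": row[-1]}
        for name, value in zip(features, row[:-1]):
            item[name] = float(value)
        items.append(item)
    return items


print(parse_items(SAMPLE)[1])
# {'Class': 'Iris-versicolor', 'SL': 7.0, 'SW': 3.2, 'PL': 4.7, 'PW': 1.4}
```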
Classification
For the classification of a new item, we want to calculate the distances between the new item and every item in the data set. The distances in this tutorial are calculated via the generalized Euclidean formula for n dimensions. We will hold the k shortest distances in a list and in the end we will pick the class that is most common in that list.
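For two items x and y with n numeric features, the generalized Euclidean distance is the formula below; the worked case uses two made-up 2-dimensional points.

```latex
d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}
\qquad\text{e.g.}\qquad
d\big((1, 2), (4, 6)\big) = \sqrt{(1-4)^2 + (2-6)^2} = \sqrt{9 + 16} = 5
```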
In the list that holds the nearest neighbors, each element is a two-element list [distance, class].
```python
def Classify(nItem, k, Items):
    # Hold the nearest neighbors.
    # First element is distance, second is class
    neighbors = []
    for item in Items:
        # Find the Euclidean distance
        distance = EuclideanDistance(nItem, item)
        # Update neighbors,
        # either adding the current item to neighbors or not.
        neighbors = UpdateNeighbors(neighbors, item, distance, k)
    # Count the number of each class in neighbors
    count = CalculateNeighborsClass(neighbors, k)
    # Find the max in count,
    # aka the class with the most appearances.
    return FindMax(count)
```
The helper functions we still need to implement are EuclideanDistance, UpdateNeighbors, CalculateNeighborsClass and FindMax.
EuclideanDistance
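```python
import math  # For pow and sqrt

def EuclideanDistance(x, y):
    # Iterate over the keys of x (the new item), which holds only features;
    # the "Class" key of y (a data set item) is therefore never used
    S = 0  # The sum of the squared differences of the elements
    for key in x.keys():
        S += math.pow(x[key] - y[key], 2)
    return math.sqrt(S)  # The square root of the sum
```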
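On Python 3.8 and later, the same computation can be delegated to math.dist, which takes two equal-length sequences of coordinates. A minimal sketch, assuming the same dict-based items as above; euclidean_distance is a hypothetical name, not part of the article's code.

```python
import math


def euclidean_distance(x, y):
    """Euclidean distance between two items stored as dicts.

    Only the keys of x are compared, so x should hold feature values only
    (no "Class" key); y may carry extra keys such as "Class".
    """
    keys = list(x.keys())
    return math.dist([x[key] for key in keys], [y[key] for key in keys])


print(euclidean_distance({"a": 1.0, "b": 2.0}, {"a": 4.0, "b": 6.0, "Class": "X"}))  # 5.0
```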
UpdateNeighbors
Given the distance to a new item, we check whether the item should be added to the list of neighbors. First, if the list holds fewer than k neighbors, we add the item automatically, since the list is not full yet.
If the list is full (length = k), we check whether the new distance is at least as large as the largest distance in the list. If it is, we do not add the item and move on.
Otherwise, we replace the largest distance with the new one and re-sort the list in ascending order. Keeping the list sorted means the largest distance is always the last element, which Python lets us access with the index -1.
A quick note: we keep the list of closest neighbors sorted so that it is easy to update. We could write an Insertion Sort routine for this, but even a simple one adds a fair amount of code, so for the purposes of this tutorial we simply re-sort the list with Python's built-in sorted() after each change.
Code:
```python
def UpdateNeighbors(neighbors, item, distance, k):
    if len(neighbors) < k:
        # List is not full: add the new item and sort
        neighbors.append([distance, item["Class"]])
        neighbors = sorted(neighbors)
    else:
        # List is full
        # Check if the new item should be added
        if neighbors[-1][0] > distance:
            # If yes, replace the last (farthest) element with the new item
            neighbors[-1] = [distance, item["Class"]]
            neighbors = sorted(neighbors)
    return neighbors
```
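The insertion approach mentioned in the note can be written with the standard bisect module, which inserts into an already sorted list without re-sorting it. This is a sketch under the same assumptions as UpdateNeighbors ([distance, class] pairs, items with a "Class" key); update_neighbors_insort is a hypothetical name. It mutates the list in place and also returns it, so it could replace UpdateNeighbors inside Classify.

```python
import bisect


def update_neighbors_insort(neighbors, item, distance, k):
    """Keep `neighbors` as a sorted list of at most k [distance, class] pairs."""
    # Add the item if the list is not full yet, or if it is closer than
    # the current farthest neighbor (the last element)
    if len(neighbors) < k or distance < neighbors[-1][0]:
        # Insert at the sorted position instead of re-sorting the whole list
        bisect.insort(neighbors, [distance, item["Class"]])
        if len(neighbors) > k:
            neighbors.pop()  # remove the farthest neighbor
    return neighbors


n = []
for d, c in [(3.0, "B"), (1.0, "A"), (2.0, "A"), (0.5, "C")]:
    n = update_neighbors_insort(n, {"Class": c}, d, 3)
print(n)  # [[0.5, 'C'], [1.0, 'A'], [2.0, 'A']]
```

Each update finds the insertion point by binary search and then shifts at most k elements, instead of sorting all k elements again.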
CalculateNeighborsClass
We want to find the class that appears most frequently in the list of closest neighbors. We will use a dictionary, count, whose keys are the class names appearing in the list of neighbors. As we iterate through the neighbors, if a class name is not yet in the keys, we add it with a count of 1; otherwise, we increment its count by one.
```python
def CalculateNeighborsClass(neighbors, k):
    count = {}
    # min() guards against data sets with fewer than k items
    for i in range(min(k, len(neighbors))):
        c = neighbors[i][1]  # Class of the ith nearest neighbor
        if c not in count:
            # The class is not in the count dict yet.
            # Initialize it to 1.
            count[c] = 1
        else:
            # Found another neighbor of this class.
            # Increment its counter.
            count[c] += 1
    return count
```
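The standard library's collections.Counter performs the same counting in a single expression. A sketch, assuming neighbors holds [distance, class] pairs as built by UpdateNeighbors; calculate_neighbors_class is a hypothetical name, and the k parameter is dropped because the neighbor list never holds more than k entries. Since Counter is a dict subclass, its result can be passed straight to FindMax.

```python
from collections import Counter


def calculate_neighbors_class(neighbors):
    """Count how many of the neighbors belong to each class."""
    return Counter(cls for _, cls in neighbors)


print(calculate_neighbors_class([[0.5, "A"], [1.0, "B"], [2.0, "A"]]))
# Counter({'A': 2, 'B': 1})
```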
FindMax
This function receives the dictionary count we built previously and returns the class with the highest count, together with that count.
```python
def FindMax(countList):
    maximum = -1         # Hold the max
    classification = ""  # Hold the classification
    for key in countList.keys():
        if countList[key] > maximum:
            maximum = countList[key]
            classification = key
    # Return the winning class and its number of votes
    return classification, maximum
```
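The loop above can also be expressed with the built-in max() and a key function. A minimal sketch; find_max is a hypothetical name. On a tie, max() returns the first maximal key it encounters, which matches the strict > comparison in FindMax. Unlike FindMax, it raises ValueError on an empty dictionary instead of returning ("", -1).

```python
def find_max(count):
    """Return (class, votes) for the class with the highest count."""
    best = max(count, key=count.get)  # first key with the largest value
    return best, count[best]


print(find_max({"A": 2, "B": 1}))  # ('A', 2)
```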
Conclusion:
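To classify a new item, create a dictionary whose keys are the feature names and whose values are the new item's feature values, then pass it to the function Classify together with k and the data set. Classify returns a tuple holding the predicted class and the number of neighbors that voted for it.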
```python
newItem = {'PW': 1.4, 'PL': 4.7, 'SW': 3.2, 'SL': 7.0}
# Classify returns a (class, votes) tuple
print(Classify(newItem, 3, items))
```
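When the new item's values arrive as a plain list (for example, parsed from user input), they can be paired with the feature names using zip. This is a sketch assuming the values are given in the same order as the feature names; make_item is a hypothetical helper, and the order SL, SW, PL, PW in the usage line is only illustrative.

```python
def make_item(feature_names, values):
    """Build a new-item dict (feature name -> float) suitable for Classify.

    feature_names and values must be in the same order.
    """
    if len(feature_names) != len(values):
        raise ValueError("expected one value per feature")
    return {name: float(v) for name, v in zip(feature_names, values)}


print(make_item(["SL", "SW", "PL", "PW"], [7.0, 3.2, 4.7, 1.4]))
# {'SL': 7.0, 'SW': 3.2, 'PL': 4.7, 'PW': 1.4}
```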
Complete Code:
```python
import math  # For pow and sqrt
from random import shuffle


def ReadData(fileName):
    # Read the file, splitting it into lines
    with open(fileName, 'r') as f:
        lines = f.read().splitlines()

    # Split the first line by commas and drop the last element
    # (the class column header). The rest are the feature names.
    features = lines[0].split(',')[:-1]

    items = []
    for i in range(1, len(lines)):
        line = lines[i].split(',')
        # The last value on the line is the item's class
        itemFeatures = {"Class": line[-1]}
        for j in range(len(features)):
            f = features[j]      # Get the feature name at index j
            v = float(line[j])   # Convert the feature value to float
            itemFeatures[f] = v  # Add the feature value to the dict
        items.append(itemFeatures)

    shuffle(items)
    return items


### Auxiliary Functions ###
def EuclideanDistance(x, y):
    # Iterate over the keys of x (the new item), which holds only features
    S = 0  # The sum of the squared differences of the elements
    for key in x.keys():
        S += math.pow(x[key] - y[key], 2)
    return math.sqrt(S)  # The square root of the sum


def CalculateNeighborsClass(neighbors, k):
    count = {}
    # min() guards against data sets with fewer than k items
    for i in range(min(k, len(neighbors))):
        c = neighbors[i][1]  # Class of the ith nearest neighbor
        if c not in count:
            # The class is not in the count dict yet. Initialize it to 1.
            count[c] = 1
        else:
            # Found another neighbor of this class. Increment its counter.
            count[c] += 1
    return count


def FindMax(Dict):
    maximum = -1
    classification = ""
    for key in Dict.keys():
        if Dict[key] > maximum:
            maximum = Dict[key]
            classification = key
    return classification, maximum


### Core Functions ###
def Classify(nItem, k, Items):
    # Hold nearest neighbors. First element is distance, second is class
    neighbors = []
    for item in Items:
        # Find the Euclidean distance
        distance = EuclideanDistance(nItem, item)
        # Update neighbors,
        # either adding the current item to neighbors or not.
        neighbors = UpdateNeighbors(neighbors, item, distance, k)
    # Count the number of each class in neighbors
    count = CalculateNeighborsClass(neighbors, k)
    # Find the max in count, aka the class with the most appearances
    return FindMax(count)


def UpdateNeighbors(neighbors, item, distance, k):
    if len(neighbors) < k:
        # List is not full: add the new item and sort
        neighbors.append([distance, item["Class"]])
        neighbors = sorted(neighbors)
    else:
        # List is full
        # Check if the new item should be entered
        if neighbors[-1][0] > distance:
            # If yes, replace the last (farthest) element with the new item
            neighbors[-1] = [distance, item["Class"]]
            neighbors = sorted(neighbors)
    return neighbors


### Main ###
def main():
    items = ReadData('data.txt')
    newItem = {'PW': 1.4, 'PL': 4.7, 'SW': 3.2, 'SL': 7.0}
    print(Classify(newItem, 3, items))


if __name__ == "__main__":
    main()
```