One of the bottlenecks of any KNN algorithm is calculating the distances between data points. These pairwise distance calculations run in O(n²) time because every combination of two points must be compared for a dataset with n rows. While this doesn’t present too much of an issue for smaller datasets, the runtime increases rapidly (quadratically, not exponentially) as n gets larger: doubling the number of rows roughly quadruples the work.
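To put numbers on that growth, here is a small sketch that counts the pairs of rows for a given n. The helper name n_pairs is my own and is only for illustration.

```r
# Number of unordered pairs of rows in a dataset with n rows: n * (n - 1) / 2
n_pairs <- function(n) n * (n - 1) / 2

n_pairs(5000)                    # 12497500 pairs
n_pairs(10000)                   # 49995000 pairs
n_pairs(10000) / n_pairs(5000)   # roughly 4: twice the rows, four times the pairs
```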
For example, R’s dist() function takes approximately 3 seconds for a matrix of dimensions 5000×100, and 15 seconds for 10000×100. This function also performs very poorly as the number of columns increases – dist() takes 30 seconds when the matrix is 2000×1000. While there is no way around computing the distance between every pair of rows, there are faster implementations depending on your data types.
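If you want to check timings like these on your own machine, something along these lines should work. It is only a sketch: the helper time_dist, the seed, and the random binary test data are my assumptions, and the numbers you get will depend on your hardware and R build.

```r
set.seed(42)  # arbitrary seed so the random matrix is reproducible

# Hypothetical helper: time dist() on a random binary matrix of a given shape
time_dist <- function(n_rows, n_cols) {
  X <- matrix(rbinom(n_rows * n_cols, size = 1, prob = 0.5), nrow = n_rows)
  system.time(dist(X))[["elapsed"]]  # elapsed wall-clock seconds
}

time_dist(1000, 100)
time_dist(2000, 100)  # twice the rows, so expect roughly four times the runtime
```

Scale the dimensions up gradually, since both the runtime and the memory needed for the result grow with the square of the number of rows.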
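If you have purely binary values (0,1) in your data, then simple matrix multiplication can save you seconds or even minutes off the runtime of the distance calculations. As outlined here, the squared Euclidean distance between binary data points can be quickly calculated using a matrix cross product. For binary data this is simply the number of columns in which two rows differ (the Hamming distance), and it is all KNN needs, since ranking neighbors by squared distance gives the same order as ranking by distance. If you need the true Euclidean distance, take the square root of the result. So to calculate the squared Euclidean distance of binary data points in R, you can do: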
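binary_dist <- function(X) {
  D <- (!X) %*% t(X)  # D[i, j]: number of columns where row i is 0 and row j is 1
  return(D + t(D))    # add the reverse mismatches to get the full count
}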
This works by taking the cross product of NOT X and X. Entry (i, j) of (!X) %*% t(X) counts the columns where row i of X is 0 (so !X is 1) and row j of X is 1. Adding the transpose counts the opposite case, where row i is 1 and row j is 0, so D + t(D) is the total number of columns in which the two rows differ. In case you don’t have matrix multiplication memorized, this graphic from Interactive Mathematics is particularly useful (remember that t(X) means that the columns in the second matrix would actually be the rows from X, so it’s still comparing row vs. row).
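To convince yourself that the cross product really counts mismatches, it helps to try it on a matrix small enough to check by hand. This is a minimal sketch: binary_dist is just a hypothetical wrapper around the two lines shown earlier, and the 3×4 matrix is made up for illustration.

```r
# Hypothetical wrapper around the cross-product trick
binary_dist <- function(X) {
  D <- (!X) %*% t(X)  # D[i, j]: columns where row i is 0 and row j is 1
  D + t(D)            # add the reverse case: row i is 1 and row j is 0
}

# Three binary rows, small enough to count mismatches by eye
X <- matrix(c(1, 0, 1, 0,
              1, 1, 0, 0,
              0, 0, 1, 1), nrow = 3, byrow = TRUE)

H <- binary_dist(X)
H  # rows 1 and 2 differ in 2 columns, rows 1 and 3 in 2, rows 2 and 3 in 4

# The result matches dist() squared, and dist() with the Manhattan metric
E <- as.matrix(dist(X))
all.equal(unname(E)^2, unname(H))                                        # TRUE
all.equal(unname(as.matrix(dist(X, method = "manhattan"))), unname(H))  # TRUE
```

Note that the result is the mismatch count, which equals the squared Euclidean distance for binary data. Use sqrt(H) if you need values on the same scale as dist().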
So how does this distance calculation compare against R’s dist() function? If you’re a fan of 3D plots, here is one that compares the runtime of dist() to matrix multiplication based on the number of rows and columns.
As we can see from the plot above, the runtime of R’s dist() function climbs steeply as the number of rows and columns increases. It should be noted that for data with fewer than 100 columns, dist() and the matrix cross product have effectively the same runtime. But if you have binary data, odds are you’ll also have a lot of columns because you’ve created a lot of dummy columns from a categorical column. So for categorical columns that have more than 100 distinct values, a matrix cross product will be significantly faster than R’s built-in dist() function.
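A comparison like this can be reproduced with a few lines of R. The sketch below is not the exact benchmark behind the plot: the matrix dimensions, the seed, and the wrapper name binary_dist are my assumptions, and the timings will vary with your machine and the BLAS library R is linked against.

```r
set.seed(1)  # arbitrary seed for reproducible test data

# Hypothetical wrapper around the cross-product trick
binary_dist <- function(X) {
  D <- (!X) %*% t(X)
  D + t(D)
}

# Random binary matrix; the dimensions are illustrative
X <- matrix(rbinom(1000 * 500, size = 1, prob = 0.5), nrow = 1000)

# Time both approaches, keeping the results so they can be compared
t_dist  <- system.time(d1 <- dist(X, method = "manhattan"))[["elapsed"]]
t_cross <- system.time(d2 <- binary_dist(X))[["elapsed"]]

c(dist = t_dist, cross_product = t_cross)

# Both approaches should produce the same mismatch counts
all.equal(unname(as.matrix(d1)), unname(d2))
```

I used the Manhattan metric for dist() here so that both sides return the same quantity on binary data, which makes the equality check meaningful.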
So there you have it. Matrix cross products can be much faster than existing functions when it comes to calculating pairwise distances. I imagine this is because matrix multiplication is heavily optimized and implemented at a low level (R hands %*% off to a BLAS routine), while functions such as dist() have to deal with R’s set of data types, storage structures, and overhead.
The disparity in runtimes as the number of columns increases may be due to how R lays out data in memory. A data frame is a list of columns, and a matrix is a single vector stored in column-major order, so reading one row’s values means jumping across every column rather than scanning contiguous memory. Meanwhile, taking the cross product of !X and the transpose of X means that the rows of X become columns for t(X), which could further reduce runtime.
Stay tuned for another article where I show how this distance method, as well as a host of other speedups, can be used to calculate KNN predictions in a fraction of the time that Caret’s KNN method takes.