10/11/2013

Pairwise distances in Matlab and Octave

It is always rewarding to spend some time vectorizing computations in Matlab: the code runs faster, you are happy to see all those ugly for-loops roll up and turn into elegant linear algebra formulas, and sometimes you can even exploit hidden structure in your computations (see covariance matrix estimation).

Of course, you shouldn't be too fanatical about it: code can become less readable, and memory consumption can grow disproportionately (see the ndgrid example below, or recall some repmat constructions). I would like to have Octave- and Python-style broadcasting implemented in Matlab, but they claim that bsxfun, arrayfun, and the JIT-accelerated loop implementation are fast enough. I strongly disagree.
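To make the repmat versus broadcasting point concrete, here is a minimal sketch. The matrix A and the vector v are hypothetical names chosen only for this illustration; the task (subtracting one column vector from every column of a matrix) is an assumed example, not something taken from the computations below.

```matlab
% Hypothetical example: subtract a column vector v from every column of A.
A = rand(4, 6);
v = mean(A, 2);                          % 4x1 vector of row means

% repmat first builds a full 4x6 copy of v, then subtracts.
C_repmat = A - repmat(v, 1, size(A, 2));

% bsxfun expands v along its singleton dimension on the fly,
% without an explicit 4x6 temporary in user code.
C_bsxfun = bsxfun(@minus, A, v);

fprintf('Max difference: %e\n', max(abs(C_repmat(:) - C_bsxfun(:))));
```

Both lines compute the same matrix; the difference is only in how much intermediate storage the expression asks for, which starts to matter when the replicated operand is large.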

Anyway. Suppose you want to compute a table of pairwise squared distances between two sets of vectors in $d$-dimensional space: $$(P)_{ij} = \|x_i - y_j\|^2_2, \;\; i=1,\dots,n_1,\; j = 1,\dots,n_2, \text{ and } P\in\mathbb{R}^{n_1\times n_2}.$$ We store $x_i$ and $y_j$ as columns of the matrices $X$ and $Y$ respectively: $$X = \left[ x_1, \dots, x_{n_1}\right] \in \mathbb{R}^{d\times n_1},$$ $$Y = \left[ y_1, \dots, y_{n_2}\right] \in \mathbb{R}^{d\times n_2}.$$ We can expand the squared $\ell_2$-norm: $$(P)_{ij} = \|x_i - y_j\|^2_2 = x_i^\top x_i + y_j^\top y_j - 2 x_i^\top y_j,$$ so the expression for the matrix $P$ can be written using matrix multiplications: $$P = a\mathbb{1}_{(n_2\times 1)}^\top + \mathbb{1}_{(n_1\times 1)} b^\top - 2 X^\top Y,$$ where $a$ and $b$ are vectors containing the squared $\ell_2$ norms of the vectors $x_i$ and $y_j$: $a_i = \|x_i\|_2^2, \;\; a\in\mathbb{R}^{n_1\times 1}$, and $b_j = \|y_j\|_2^2, \;\; b\in\mathbb{R}^{n_2\times 1}$.
Here $\mathbb{1}_{(m\times 1)} = \left[1, \dots, 1\right]^\top \in \mathbb{R}^{m\times 1}$ is the column vector of ones.
Finally, this formula can be written as a vectorized Matlab expression (in the code a and b are computed as row vectors, so a' plays the role of $a$ and b the role of $b^\top$):

a = sum(X.^2, 1);
b = sum(Y.^2, 1);
P = a' * ones(1, N2)  + ones(N1, 1) * b - 2 * (X' * Y); 
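As a sanity check, the same formula can be compared against a naive double loop on a small problem. The sketch below also uses bsxfun in place of the two outer products with ones(...); that substitution is a variation, not the expression above. The sizes and the names P_bsx, P_loop and max_err are assumptions made for this illustration.

```matlab
% Small random problem; sizes chosen only for illustration.
d = 3; N1 = 4; N2 = 5;
X = rand(d, N1);
Y = rand(d, N2);

% Squared norms of the columns (row vectors of length N1 and N2).
a = sum(X.^2, 1);
b = sum(Y.^2, 1);

% bsxfun adds a' (N1x1) and b (1xN2) into an N1xN2 matrix
% without forming the two outer products with ones(...) explicitly.
P_bsx = bsxfun(@plus, a', b) - 2 * (X' * Y);

% Reference: naive double loop over all pairs of columns.
P_loop = zeros(N1, N2);
for i = 1 : N1
    for j = 1 : N2
        diff_ij = X(:, i) - Y(:, j);
        P_loop(i, j) = diff_ij' * diff_ij;
    end
end

max_err = max(abs(P_bsx(:) - P_loop(:)));
fprintf('Max difference between bsxfun and loop versions: %e\n', max_err);
```

The two results should agree up to round-off; any larger discrepancy would point to a mix-up in the orientation of a or b.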

This is a fast and elegant implementation whose running time is comparable to that of Matlab's built-in pdist2 (roughly within a factor of two in either direction). It is also much faster and more memory-efficient than the ndgrid-based pairwise distance computation (see below). Take a look at the code comparing three different implementations of this problem. The notation is consistent with the formulas above.

%generate data
d = 100; % dimensionality
N1 = 500; % number of vectors in X
N2 = 400; % number of vectors in Y
X = rand(d, N1);
Y = rand(d, N2);

%built-in
tic();
pd = pdist2(X', Y');
fprintf('"built-in":\t%.5f seconds\n', toc());

%ndgrid
tic();
[gx, gy] = ndgrid(1 : N1, 1 : N2);
P_grid = reshape(sqrt(sum( (X(:, gx) - Y(:, gy)).^2, 1)), [N1, N2]);
t = toc();
fprintf('"ndgrid":\t%.5f seconds. Max err: %e\n', t, max(max(abs(P_grid - pd))));

%matrix
tic();
a = sum(X.^2, 1);
b = sum(Y.^2, 1);
%I use abs to get rid of small (1e-15) negative diagonal values,
%produced by computational errors
P =  abs(a' * ones(1, N2)  + ones(N1, 1) * b - 2 * (X' * Y)); 
P = sqrt(P); %to be consistent with pdist2
t = toc();
fprintf('"matrix":\t%.5f seconds. Max err: %e\n', t, max(max(abs(P - pd))));

Output on my machine:
"built-in": 0.02721 seconds
"ndgrid": 0.67575 seconds. Max err: 0.000000e+00
"matrix": 0.01662 seconds. Max err: 1.243450e-14