source link: https://jott.live/markdown/euclidean_distance

Optimizing Euclidean Distance Queries Mathematically

I saw an interesting post over at lobste.rs about efficiently computing the Euclidean distance between vectors using SIMD intrinsics. I thought it was an interesting problem and decided to take a stab at making it fast on a GPU.

It boils down to the author wanting to find the closest vector in a set of vectors DB to a given query vector q:

\min_{d \in DB} \sqrt{\sum_i (d_i-q_i)^2}

But the author notes that the square root is monotonic, so it doesn't change which vector attains the minimum and we can drop it. We instead find

\min_{d \in DB} \sum_i (d_i-q_i)^2

or the element in the database that has the smallest Euclidean distance (squared) from the query vector.

Batching

Immediately we can see that the problem as stated is memory bound. That is, we're going to be able to compute a simple subtraction and multiplication far more quickly than we will be able to pull a vector from RAM. To get around this type of issue, we can batch the queries. This way we can reuse the vectors pulled from the DB for more than one query and hope to utilize the processor more fully.

However it's not super clear how we can batch the queries...

\forall q \in Q : \min_{d \in DB} \sum_i (d_i-q_i)^2

It looks like we should be able to use three for-loops: two outer loops iterating over the qs and ds, and one inner loop that sums the (d_i-q_i)^2 terms over the elements of each vector.
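
A minimal sketch of that triple loop, assuming queries and db are plain Python lists of equal-length vectors (the names here are illustrative, not from the original post):

def naive_closest(queries, db):
    # for each query, return the index of the database vector with the
    # smallest squared Euclidean distance
    results = []
    for q in queries:                    # outer loop over queries
        best_idx, best_dist = -1, float("inf")
        for idx, d in enumerate(db):     # outer loop over database vectors
            dist = 0.0
            for d_i, q_i in zip(d, q):   # inner loop over vector elements
                dist += (d_i - q_i) ** 2
            if dist < best_dist:
                best_idx, best_dist = idx, dist
        results.append(best_idx)
    return results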

Simplification

But can we do anything simpler than (d_i-q_i)^2? Removing the square root was easy enough; can we remove anything else? First we expand the terms:

\sum_i d_i^2 - \sum_i 2 d_i q_i + \sum_i q_i^2

Note that \sum_i q_i^2 will be the same for every d, so removing it has no impact on which d attains the minimum. Now we have:

\sum_i d_i^2 - \sum_i 2 d_i q_i

The \sum_i d_i^2 term certainly looks unnecessary to calculate every time, so let's precalculate it as d_{bias} = \sum_i d_i^2.

d_{bias} - \sum_i 2 d_i q_i

Better yet, we can divide the whole expression by 2; it remains monotonic in the Euclidean distance, so the closest vector is unchanged. We'll update our bias term to be

d_{bias} = \frac{1}{2}\sum_i d_i^2

and plug everything back into the original equation:

\forall q \in Q : \min_{d, d_{bias} \in DB} d_{bias} -\sum_i d_i q_i

For mental simplicity we can also negate the problem and find

\forall q \in Q : \max_{d, d_{bias} \in DB} \sum_i d_i q_i + d_{bias}

with the precalculated bias as

d_{bias} = -\frac{1}{2}\sum_i d_i^2

which can be rewritten as

R = Q \cdot DB + DB_{bias}

where, for each query (each row of R), we want the index i that attains \max_{i} R_{i}.
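
As a quick sanity check (my addition, not part of the original post), a small CPU-only PyTorch sketch confirms that the argmax of Q \cdot DB + DB_{bias} picks the same entries as a brute-force squared-distance argmin:

import torch

dim, n_db, n_q = 8, 1000, 16
db = torch.randn(dim, n_db)   # columns are database vectors
q = torch.randn(n_q, dim)     # rows are query vectors

# brute force: squared Euclidean distance from every query to every db vector
dists = ((q[:, :, None] - db[None, :, :]) ** 2).sum(dim=1)   # (n_q, n_db)
brute = torch.argmin(dists, dim=1)

# reformulation: matmul plus precomputed bias, then argmax
db_bias = -0.5 * (db ** 2).sum(dim=0)                        # (n_db,)
fast = torch.argmax(q @ db + db_bias, dim=1)

# with random data, ties are essentially impossible, so these should match
assert torch.equal(brute, fast)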

Matrix multiplication

This final equation shows us that we can view the problem as a matrix multiplication with bias followed by an argmax. Luckily there are plenty of libraries that can do this for us, and since the problem is well specified, it'll be easy to implement on a GPU.

Here is a full example using floating point numbers with PyTorch, a fast Python library:

import torch
import time

batch = 1500
db_size = 1000000

print("preprocessing the database")
# each column of db is a 144-dimensional database vector
db = torch.randn(144, db_size).to(torch.device('cuda'))
# precomputed bias term: -1/2 * ||d||^2 for every database vector
db_bias = -0.5 * torch.sum(db**2, axis=0)

print("generating a query")
# each row of query is a 144-dimensional query vector
query = torch.randn(batch, 144).to(torch.device('cuda'))

print("running the query")
t = time.time()
# (batch, db_size) matrix of scores; a larger score means a smaller distance
result = (query @ db) + db_bias
# index of the closest database vector for each query
idxs = torch.argmax(result, axis=1)
print(idxs[0])  # forces a device sync so the timing below is meaningful
print(f"{batch} queries in {time.time() - t:.4f} seconds")

which runs 1500 queries in about 70 milliseconds on a V100.
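
One practical aside (my addition, not from the post): CUDA kernels launch asynchronously, so the print of idxs[0] is what forces the GPU to finish before the elapsed time is read. Using torch.cuda.synchronize() makes that intent explicit:

torch.cuda.synchronize()   # wait for setup work (copies, bias computation) to finish
t = time.time()
result = (query @ db) + db_bias
idxs = torch.argmax(result, axis=1)
torch.cuda.synchronize()   # make sure the matmul and argmax have completed
print(f"{batch} queries in {time.time() - t:.4f} seconds")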

