Optimising Nearest Neighbor Program
Hi Guys,
I am trying to optimise this code so that it runs in under 10 seconds for N = 20k.
It currently takes around 40 seconds to run.
I think that I need to vectorise some or all of the loops so that the calculations are done at the same time, but I cannot figure out how to do it.
Any help would be much appreciated as we are learning from home with next to no support from the University.
Here is the code.
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b = 1:N
            off = 0;
            if (pos(n,a) <= 0.25 && pos(n,b) >= 0.75)
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2.0;
            if n == 3
                s(a,b) = sqrt(s(a,b));
            end
        end
    end
end
match = zeros(1,N);
for a = 1:N
    mindist = 1e10;
    for c = 1:N
        if (a ~= c)
            mindist = min(s(a,c), mindist);
            if (mindist == s(a,c))
                match(a) = c;
            end
        end
    end
end
4 Comments
KSSV
on 30 Oct 2020
Read about knnsearch.
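For anyone following up on that pointer, here is a minimal sketch of how knnsearch could be called on data laid out like pos in the question. It assumes the Statistics and Machine Learning Toolbox is installed, the seed and N are arbitrary illustrative values, and it uses plain Euclidean distance, so it does not apply the off correction from the original loops.

```matlab
% Sketch only: knnsearch needs the Statistics and Machine Learning Toolbox.
rng(1234);                 % arbitrary seed, for repeatability
N   = 1000;                % illustrative size
pos = rand(3, N);          % same layout as the question: 3 x N

% knnsearch expects one point per ROW, hence the transposes.
% K = 2 because the closest hit for each query point is the point itself
% (assuming there are no duplicate points in the data).
[idx, dist] = knnsearch(pos.', pos.', 'K', 2);

match   = idx(:, 2).';     % index of the nearest other point
mindist = dist(:, 2).';    % corresponding Euclidean distance
```

The second column is used because each point is its own nearest neighbour at distance zero; with duplicate points that assumption would need checking.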
Samuel Leeney
on 30 Oct 2020
KSSV
on 30 Oct 2020
Refer this, it might help you: https://in.mathworks.com/matlabcentral/fileexchange/44334-nearest-neighboring-particle-search-using-all-particles-search
Samuel Leeney
on 30 Oct 2020
Accepted Answer
More Answers (3)
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos = rand(3,N);
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is needed here: && only accepts scalar operands.
    off(P(a) <= 0.25 & P(b) >= 0.75) = +1 ;
    off(P(a) >= 0.75 & P(b) <= 0.25) = -1 ;
    S = S + (P(a) - P(b) + off).^2;
end
S = sqrt(S);
9 Comments
Samuel Leeney
on 30 Oct 2020
KSSV
on 30 Oct 2020
I have copied what you gave in the question...try using a single &.
Samuel Leeney
on 30 Oct 2020
Edited: Samuel Leeney
on 30 Oct 2020
KSSV
on 30 Oct 2020
Edited the answer.....If it throws error...try using single &.
Samuel Leeney
on 30 Oct 2020
KSSV
on 30 Oct 2020
No way. This should definitely be faster than the given loops.
Samuel Leeney
on 30 Oct 2020
KSSV
on 30 Oct 2020
How did you calculate the time? Check the below:
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos = rand(3,N);
% First version: nested loops from the question.
t1 = tic ;
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b = 1:N
            off = 0;
            if (pos(n,a) <= 0.25 && pos(n,b) >= 0.75)
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2.0;
            if n == 3
                s(a,b) = sqrt(s(a,b));
            end
        end
    end
end
t1 = toc(t1) ;
% Second version: index grids instead of the two inner loops.
t2 = tic ;
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is needed here: && only accepts scalar operands.
    off(P(a) <= 0.25 & P(b) >= 0.75) = +1 ;
    off(P(a) >= 0.75 & P(b) <= 0.25) = -1 ;
    S = S + (P(a) - P(b) + off).^2;
end
S = sqrt(S);
t2 = toc(t2) ;
Now check t1 and t2 for different input sizes. My bet is that the second code will always be faster. I am comparing only the first part of the code.
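A single tic/toc reading can be noisy, so one way to make the comparison more robust is to repeat each version a few times and look at the median. The sketch below does that for the two versions compared here; N, nRuns and the index-grid names ia and ib are illustrative choices of mine, and the final sqrt is moved out of the loop for brevity.

```matlab
% Sketch: repeat each timing and take the median to reduce noise.
N     = 200;          % illustrative size, small enough to repeat quickly
nRuns = 5;            % arbitrary number of repetitions
rng(1234);
pos   = rand(3, N);

tLoop = zeros(1, nRuns);
tVec  = zeros(1, nRuns);
for r = 1:nRuns
    % Nested-loop version from the question.
    tStart = tic;
    s = zeros(N, N);
    for n = 1:3
        for a = 1:N
            for b = 1:N
                off = 0;
                if pos(n,a) <= 0.25 && pos(n,b) >= 0.75
                    off = 1;
                elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                    off = -1;
                end
                s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2;
            end
        end
    end
    s = sqrt(s);
    tLoop(r) = toc(tStart);

    % Index-grid version, using element-wise & on the arrays.
    tStart = tic;
    [ia, ib] = meshgrid(1:N, 1:N);
    S = zeros(N, N);
    for n = 1:3
        P   = pos(n, :);
        off = zeros(N, N);
        off(P(ia) <= 0.25 & P(ib) >= 0.75) = 1;
        off(P(ia) >= 0.75 & P(ib) <= 0.25) = -1;
        S = S + (P(ia) - P(ib) + off).^2;
    end
    S = sqrt(S);
    tVec(r) = toc(tStart);
end

fprintf('median loop time:       %.4f s\n', median(tLoop));
fprintf('median index-grid time: %.4f s\n', median(tVec));
```

The built-in timeit function is another option if each version is wrapped in its own function.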
Samuel Leeney
on 30 Oct 2020
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I removed your handling of the offset (I am not sure what its purpose is), and this is a much faster method using Delaunay triangulation:
clear
N = 20000;
nd = 3;
pos = rand(nd,N);
tic
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b = 1:N
            s(a,b) = s(a,b) + (pos(n,a) - pos(n,b))^2.0;
            if n == 3
                s(a,b) = sqrt(s(a,b));
            end
        end
    end
end
match = zeros(1,N);
for a = 1:N
    mindist = 1e10;
    for c = 1:N
        if (a ~= c)
            mindist = min(s(a,c), mindist);
            if (mindist == s(a,c))
                match(a) = c;
            end
        end
    end
end
toc % Elapsed time is 56.414344 seconds.
% Find the nearest neighbour of each point within the same set of points, in 2D or 3D
% INPUT: pos is array of size (nd x N), coordinates of N points in R^nd
tic
T = delaunay(pos.');
p = nchoosek(1:size(T,2),2);
P = T(:,p);
P = reshape(P,[],2);
P = unique(sort(P,2),'rows');
P1 = P(:,1);
P2 = P(:,2);
d2 = sum((pos(:,P2)-pos(:,P1)).^2,1);
A = [P1(:), d2(:), P2(:);
P2(:), d2(:), P1(:)];
A = sortrows(A,[1 2]);
b = [true; diff(A(:,1),1)>0];
A = A(b,:);
nn = A(:,3).'; % index of nearest neighbour
d = sqrt(A(:,2)).'; % corresponding distance
toc % Elapsed time is 0.349317 seconds.
isequal(match,nn) % 1
This doesn't help you with your exercise, but I am still posting it here for future readers looking for a fast method.
4 Comments
Samuel Leeney
on 30 Oct 2020
Bruno Luong
on 30 Oct 2020
Well, I am not running on the same computer or the same data as you. I am not sure you need to bother.
Samuel Leeney
on 30 Oct 2020
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I ran it 10 times and isequal(match,nn) returned TRUE every time, so the two answers matched in all 10 runs with random points.
> for k=1:10; benchnntest; end
Elapsed time is 49.127755 seconds.
Elapsed time is 0.516644 seconds.
ans =
logical
1
Elapsed time is 51.189459 seconds.
Elapsed time is 0.459055 seconds.
ans =
logical
1
Elapsed time is 50.431960 seconds.
Elapsed time is 0.465887 seconds.
ans =
logical
1
Elapsed time is 50.426246 seconds.
Elapsed time is 0.454885 seconds.
ans =
logical
1
Elapsed time is 50.651649 seconds.
Elapsed time is 0.567084 seconds.
ans =
logical
1
Elapsed time is 50.889422 seconds.
Elapsed time is 0.461514 seconds.
ans =
logical
1
Elapsed time is 50.678441 seconds.
Elapsed time is 0.491820 seconds.
ans =
logical
1
Elapsed time is 50.476219 seconds.
Elapsed time is 0.451430 seconds.
ans =
logical
1
Elapsed time is 52.659327 seconds.
Elapsed time is 0.443114 seconds.
ans =
logical
1
Elapsed time is 52.004992 seconds.
Elapsed time is 0.459873 seconds.
ans =
logical
1
>>
Image Analyst
on 30 Oct 2020
MATLAB stores arrays in column-major order, which means that elements along the leftmost index are adjacent in memory, so loops that vary the leftmost index fastest tend to run quicker. MATLAB goes down the rows of one column first, then moves over to the next column and goes down its rows. So this slow code
for row = 1 : rows
    for col = 1 : columns
        s(row, col) = whatever; % col iterates fastest
    end
end
will (or may) be slower than this fast code
for col = 1 : columns
    for row = 1 : rows
        s(row, col) = whatever; % row iterates fastest
    end
end
Note that, in your code, n is the leftmost index of your arrays, yet you had the n loop as the outer loop, which is the slowest possible way to do it. If possible, see if you can move n to an inner loop. I've had luck in the past getting nested loops to speed up by doing that.
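To see whether loop order matters on a given machine, a small experiment along these lines may help. It is only a sketch: the array size and the row + col filler expression are arbitrary stand-ins for "whatever", and the size of the gap (if any) depends on the MATLAB version and hardware.

```matlab
rows    = 2000;                    % illustrative sizes
columns = 2000;
s1 = zeros(rows, columns);         % preallocate both results
s2 = zeros(rows, columns);

tic
for row = 1 : rows
    for col = 1 : columns
        s1(row, col) = row + col;  % inner loop jumps from column to column
    end
end
tRowOuter = toc;

tic
for col = 1 : columns
    for row = 1 : rows
        s2(row, col) = row + col;  % inner loop walks down one column
    end
end
tColOuter = toc;

fprintf('row loop outermost:    %.4f s\n', tRowOuter);
fprintf('column loop outermost: %.4f s\n', tColOuter);
```

Both loops fill the array with identical values, so any timing difference comes from the memory access pattern only.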
8 Comments
Samuel Leeney
on 30 Oct 2020
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
"I think I need to vectorise some of the loops to gain more time, but can't seem to figure this out."
The entire double for-loop
for a = 1:N
    mindist = 1e10;
    for c = 1:N
        if (a ~= c)
            mindist = min(s(a,c), mindist);
            if (mindist == s(a,c))
                match(a) = c;
            end
        end
    end
end
can be replaced by
s(1:N+1:end) = Inf;
[~,match] = min(s,[],2);
match = match.'; % row vector
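A tiny worked example may make the replacement easier to follow. The 3-by-3 matrix below is made up for illustration; the diagonal is set to Inf so that a point can never be matched with itself, and min along dimension 2 then returns the column index of the smallest entry in each row.

```matlab
% Made-up symmetric distance matrix for three points.
s = [0 2 5;
     2 0 1;
     5 1 0];
N = size(s, 1);

s(1:N+1:end) = Inf;                 % linear indices 1, N+2, 2N+3, ... are the diagonal
[mindist, match] = min(s, [], 2);   % row-wise minimum and its column index
match   = match.';                  % row vector: [2 3 2]
mindist = mindist.';                % row vector: [2 1 1]
```

One small difference from the loop version: on exact ties min returns the first index, whereas the loop keeps the last one.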
Samuel Leeney
on 30 Oct 2020
Bruno Luong
on 30 Oct 2020
First loop
pos_r = reshape(pos,[nd 1 N]);
s = zeros(N);
for b = 1:N
    posb = pos_r(:,:,b);
    off = (pos<=0.25 & posb>=0.75) - ...
          (pos>=0.75 & posb<=0.25);
    sb = sum((pos-posb+off).^2,1);
    s(:,b) = sb(:);
end
s = sqrt(s);
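If you want to convince yourself that the column-at-a-time version gives the same matrix as the original triple loop, a check like the following on a small N should do. It is a sketch: N, the seed and the names sRef and maxErr are mine, and it relies on implicit expansion, so it assumes R2016b or later.

```matlab
rng(1234);              % arbitrary seed
N   = 200;              % small illustrative size so the triple loop is quick
nd  = 3;
pos = rand(nd, N);

% Reference: the triple loop from the question (sqrt moved to the end).
sRef = zeros(N);
for n = 1:nd
    for a = 1:N
        for b = 1:N
            off = 0;
            if pos(n,a) <= 0.25 && pos(n,b) >= 0.75
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            sRef(a,b) = sRef(a,b) + (pos(n,a) - pos(n,b) + off)^2;
        end
    end
end
sRef = sqrt(sRef);

% One column of s per iteration, all coordinates at once.
s = zeros(N);
for b = 1:N
    posb = pos(:, b);                             % nd x 1, expands against nd x N
    off  = (pos <= 0.25 & posb >= 0.75) - ...
           (pos >= 0.75 & posb <= 0.25);
    s(:, b) = sum((pos - posb + off).^2, 1).';
end
s = sqrt(s);

maxErr = max(abs(s(:) - sRef(:)));
fprintf('largest absolute difference: %g\n', maxErr);
```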
Image Analyst
on 30 Oct 2020
We're not sure because you removed all your comments (which I'm sure you put in like all professional programmers do) before you posted your code so we don't really know what it's doing unless we spend a lot of time trying to understand it. Time that could be eliminated if you just put the comments back into the code to explain what it's doing. Looks like you're doing some kind of thresholding/masking but not sure.
Samuel Leeney
on 30 Oct 2020
Samuel Leeney
on 30 Oct 2020
Image Analyst
on 30 Oct 2020
Samuel, going by your description, I'd try something like this to find the closest point.
N = 1000;
xyz = rand(N, 3); % Get N randomly located points in 3-D.
for k = 1 : N
    % Get the squared distance of point k to every other point in the array.
    distancesSquared = ((xyz(k, 1) - xyz(:, 1)) .^2 + ...
        (xyz(k, 2) - xyz(:, 2)) .^2 + ...
        (xyz(k, 3) - xyz(:, 3)) .^2);
    % We don't want to consider the distance of the point to itself, so set any zeros to infinity.
    distancesSquared(distancesSquared==0) = inf;
    % Find the min value and the index of that min for the other points.
    [minDist2, index] = min(distancesSquared);
    % Print it out.
    fprintf('Point %d at (%.2f, %.2f, %.2f) is closest to point %d at (%.2f, %.2f, %.2f).\n',...
        k, xyz(k, 1), xyz(k, 2), xyz(k, 3), index, xyz(index, 1), xyz(index, 2), xyz(index, 3));
end
It prints out stuff like:
Point 1 at (0.70, 0.37, 0.03) is closest to point 312 at (0.70, 0.37, 0.04).
Point 2 at (0.09, 0.71, 0.91) is closest to point 918 at (0.10, 0.72, 0.94).
Point 3 at (0.53, 0.95, 0.47) is closest to point 50 at (0.54, 0.93, 0.43).
etc.
Point 998 at (0.03, 0.67, 0.34) is closest to point 972 at (0.05, 0.59, 0.39).
Point 999 at (0.99, 0.87, 0.16) is closest to point 592 at (0.94, 0.93, 0.16).
Point 1000 at (0.54, 0.34, 0.42) is closest to point 540 at (0.55, 0.33, 0.42).
It doesn't do the stuff about the edges of the square though.
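If the edge handling is needed, the same per-point loop could probably be extended with the off rule from the question. The sketch below is my reading of that rule, not something confirmed in the thread: a coordinate difference is shifted by +1 or -1 when one point lies at or below 0.25 and the other at or above 0.75. It relies on implicit expansion (R2016b or later), and the names pk, d2 and match are illustrative.

```matlab
rng(1234);                     % arbitrary seed
N     = 1000;                  % illustrative size
xyz   = rand(N, 3);            % one point per row
match = zeros(1, N);

for k = 1 : N
    pk = xyz(k, :);            % 1 x 3, expands against the N x 3 array

    % Offset rule copied from the question's if/elseif, per coordinate.
    off = (pk <= 0.25 & xyz >= 0.75) - (pk >= 0.75 & xyz <= 0.25);

    % Squared distance from point k to every point, with the offset applied.
    d2 = sum((pk - xyz + off) .^ 2, 2);

    d2(k) = Inf;               % exclude the point itself by index
    [~, match(k)] = min(d2);   % index of the closest other point
end
```

Excluding the point by index rather than by testing for a zero distance also keeps genuine duplicate points as valid matches.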