Improve speed of linear interpolation in nested loops
Alessandro D on 24 Jul 2022
Commented: Alessandro D on 30 Aug 2022
I have to do one-dimensional linear interpolation many times within 4 nested loops. My x-grid is sorted, so I can use interp1q, but the code is still too slow for my purposes. I managed a simple vectorization that eliminates the innermost loop (leaving 3 loops instead of 4). It is much faster, but unfortunately still not fast enough for my problem. Any suggestions on how to improve the speed? Thanks.
Below is a minimal working example (MWE). Please bear in mind that in my real problem the grids are larger.
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
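for k_c =1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    %b_gridp(:,k_c) = linspace(b_min,b_max,nb)'; % original version (EDITED: columns now differ)
    % this grid is slightly narrower than [b_min,b_max], so a few pol_debt values
    % fall outside it; interp1q returns NaN for those points
    b_gridp(:,k_c) = linspace(b_min+rand,b_max-rand,nb)';
end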
%% Slow, not vectorized code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
%% Faster (innermost loop vectorized), but still not fast enough!
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            bnext = pol_debt(k_c,b_c,x_c);
            knext_ind = pol_kp_ind(k_c,b_c,x_c);
            pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
            dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % dim is (1,nx')
            dexit = min(max(dexit_inter,0),1); % dim is (1,nx')
            stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
        end %k_c
    end %b_c
end %x_c
toc
err = max(abs(stay_arr-stay_arr2),[],'all')
Accepted Answer
Bruno Luong on 25 Jul 2022
Edited: Bruno Luong on 25 Jul 2022
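This seems to work. Note that the vectorized part assumes all columns of b_gridp are identical, since it uses b_gridp(:,1) as a common grid: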
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
%% Full vectorized code
tic
bgridcommon = b_gridp(:,1);
Y = interp1(bgridcommon,(1:nb)',pol_debt); % nk x nb x nx
Yt = max(min(Y,nb-1),1); % clamp so that I+1 <= nb (needed at the right grid edge, and so NaN from out-of-grid points does not break sub2ind)
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = repmat(pol_kp_ind,[1 1 1 nx]);
K = reshape(K,size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
err = norm(stay_arr2(:)-stay_arr(:),Inf)
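To see why the fractional-index step works, here is a tiny standalone check. The grid xg, the values f, and the queries xq are made up for illustration. interp1(xg,(1:n)',xq) returns the fractional position of each query in index space, so I = floor(...) is the left bin and W = Y - I is the linear-interpolation weight. This holds for any sorted grid, not only a uniform one, so the code above only needs the grid to be common across k. The second part shows that adding 1 to a linear index from sub2ind moves one step along the first dimension, which is what pol_exitp(rhsilin+1) relies on (valid because I <= nb-1).

```matlab
% Index-space trick on a tiny, non-uniform grid (illustrative data only)
xg = [0; 1; 2; 4];          % sorted grid, not uniformly spaced
f  = [1; 3; 2; 5];          % values on the grid
xq = [0.5; 1.7; 3; 4];      % queries; 4 is the right endpoint
n  = numel(xg);

Y = interp1(xg, (1:n)', xq);        % fractional index of each query
I = floor(max(min(Y, n-1), 1));     % left bin, clamped so that I+1 <= n
W = Y - I;                          % weight of the right neighbour
v_trick = (1-W).*f(I) + W.*f(I+1);  % weighted sum, as in the answer
v_ref   = interp1(xg, f, xq);       % direct interpolation for comparison
disp([v_trick v_ref])               % the two columns should agree

% Linear index + 1 moves one step along the first dimension of a 3-D array
A   = reshape(1:24, [3 4 2]);
idx = sub2ind(size(A), 2, 3, 2);    % element (2,3,2)
disp([A(idx+1) A(3,3,2)])           % same element: (3,3,2)
```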
Bruno Luong on 26 Jul 2022
Edited: Bruno Luong on 26 Jul 2022
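Sorry, please disregard my earlier comment about the loop: the bin interval is not the first index. Here is corrected code that works when the grid differs across the columns of b_gridp (it relies on each column being uniformly spaced, as linspace produces).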
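The corrected code replaces any bin search with a closed-form fractional index. Written out, with f_j standing for pol_exitp(j, x', k), the steps for column k are as follows. The notation is introduced here for illustration and follows the variables bminK, bmaxK, Y, I and W in the code. The first line is exactly where uniform spacing is assumed.

```latex
% uniformly spaced column k of b_gridp (as produced by linspace)
b^{(k)}_j = b^{(k)}_1 + (j-1)\,\frac{b^{(k)}_{n_b} - b^{(k)}_1}{n_b - 1}, \qquad j = 1,\dots,n_b

% inverting for a query point b gives the fractional index Y
Y = 1 + (n_b - 1)\,\frac{b - b^{(k)}_1}{b^{(k)}_{n_b} - b^{(k)}_1}, \qquad
I = \bigl\lfloor \min(\max(Y, 1),\, n_b - 1) \bigr\rfloor, \qquad W = Y - I

% interpolated exit probability
\hat f(b) = (1 - W)\, f_I + W\, f_{I+1}
```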
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    b_gridp(:,k_c) = linspace(b_min-rand(),b_max+rand(),nb)';
end
disp('start code')
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
%% Full vectorized code
tic
K = pol_kp_ind;
bminK = reshape(b_gridp(1,K),size(K));
bmaxK = reshape(b_gridp(nb,K),size(K));
Y = 1 + (nb-1) * (pol_debt - bminK) ./ (bmaxK-bminK);
Yt = max(min(Y,nb-1),1); % clamp to 1 <= I <= nb-1; points outside the grid would be linearly extrapolated here, whereas interp1q returns NaN
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = reshape(repmat(K,[1 1 1 nx]),size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
err = norm(stay_arr2(:)-stay_arr(:),Inf)
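Both vectorized versions rely on a special grid structure: a common grid, or uniformly spaced columns. If the columns of b_gridp are sorted but arbitrarily spaced, one possible middle ground is to group the queries by the next-period index k = pol_kp_ind(k_c,b_c,x_c). You then call interp1 once per k with the whole nb-by-nx slice pol_exitp(:,:,k), which cuts the loop to nk iterations. This is a sketch under those assumptions and has not been timed on the real problem. The small sizes, the random non-uniform grid, and the names stay_ref and stay_grp are hypothetical and for illustration only. interp1 is used in place of interp1q throughout.

```matlab
% Hypothetical small problem with sorted, NON-uniform grids that differ per k
nx = 4; nb = 6; nk = 5;
b_min = -100; b_max = 300;
rng(0)
pol_debt   = b_min + (b_max-b_min)*rand(nk,nb,nx);
pol_kp_ind = randi([1,nk], nk,nb,nx);
pol_exitp  = rand(nb,nx,nk);
b_gridp    = zeros(nb,nk);
for k_c = 1:nk
    b_gridp(:,k_c) = sort([b_min; b_max; b_min + (b_max-b_min)*rand(nb-2,1)]);
end

% Reference: three nested loops, one interp1 call per (k_c,b_c,x_c)
stay_ref = zeros(nx,nk,nb,nx);
for x_c = 1:nx
    for b_c = 1:nb
        for k_c = 1:nk
            kn = pol_kp_ind(k_c,b_c,x_c);
            v  = interp1(b_gridp(:,kn), pol_exitp(:,:,kn), pol_debt(k_c,b_c,x_c)); % 1-by-nx
            stay_ref(:,k_c,b_c,x_c) = (1 - min(max(v,0),1)).';
        end
    end
end

% Grouped: one interp1 call per next-period index k
stay_grp = zeros(nx, nk*nb*nx);          % column q <-> linear index q into pol_debt
for k = 1:nk
    q = find(pol_kp_ind(:) == k);        % all queries whose next capital index is k
    if isempty(q), continue; end
    v = interp1(b_gridp(:,k), pol_exitp(:,:,k), pol_debt(q)); % numel(q)-by-nx
    stay_grp(:,q) = (1 - min(max(v,0),1)).';
end
stay_grp = reshape(stay_grp, [nx,nk,nb,nx]); % same layout as stay_arr

err = max(abs(stay_grp(:) - stay_ref(:)))
```

Whether this beats the fully vectorized answer depends on the problem sizes. The find call scans all queries once per k; if that becomes a bottleneck, it could be replaced by a single sort of pol_kp_ind(:).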