Using arrayfun to enhance the performance of code in a double for loop
2 views (last 30 days)
Show older comments
I am writing a program to numerically approximate the solution of a two point boundary value problem. The code below goes through a subset of my code and uses double for loops and if statements to execute certain aspects of the code (which should only execute for specific values of n or m). I am curious if arrayfun/ceilfun in Matlab can remove these double for loops. In order to improve performance, I don't want to continuously run the pagemldivide function for every interation of n and m. I know that the Python programming language has np.apply_along_axis and Julia has mapslices. I believe that arrayfun was created to handle situations like these. Below is a subset of the code which I am using (the actual function is around 200 - 300 lines long).
% Subset of Original Code
% Variables for numerical calculation in the double for loop
clear all; close all;
N = 15;
M = 15;
Nx = 16;
Nz = 16;
unm = rand(Nx,Nz+1,M+1,N+1);
alpha = 0.1;
p = rand(Nx,1);
alphap = rand(N+1,1);
gammap = rand(N+1,1);
gamma = 1.5;
k2 = alphap(0+1)^2 + gammap(0+1)^2;
xi_n_m = zeros(Nx,M+1,N+1);
xi_n_m_hat = fft(xi_n_m);
z_max = 1;
z_min = 0;
Dz = rand(Nx+1,Nx+1);
D = (2/(z_max-z_min))*Dz;
D2 = D*D;
D_start = D(1,:);
D_end = D(end,:);
a = 1;
% Double for loop - runs the pagemldivide function M times for all values of N
for n=0:N
for m=0:M
% Form Fnm, Jnm
Fnm = zeros(Nx,Nz+1);
Jnm = zeros(Nx,1);
% If n > 1, update Fnm based on the logic below
if(n>=1)
u_x = dx(unm(:,:,m+1,n-1+1),p);
temp = 2*u_x;
Fnm = Fnm - temp;
u_z = dz(unm(:,:,m+1,n-1+1),Dz,a);
temp = 3*u_z;
Fnm = Fnm - dx(temp,p);
end
% Do similar updates based on the index of n and m
if(m>=1)
u_x = dx(unm(:,:,m-1+1,n+1),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,m-1+1,n+1);
Fnm = Fnm - temp;
end
if(n>=1 && m>=1)
u_x = dx(unm(:,:,m-1+1,n-1+1),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,m-1+1,n-1+1);
Fnm = Fnm - temp;
end
if(n>=2)
u_x = dx(unm(:,:,m+1,n-2+1),p);
temp = 2.*u_x;
Fnm = Fnm - dx(temp,p);
temp = 3.*u_x;
Fnm = Fnm - dz(temp,Dz,a);
temp = 4.*u_x;
Fnm = Fnm - temp;
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = gamma^2.*unm(:,:,m+1,n-2+1);
Fnm = Fnm - temp;
end
if(m>=2)
temp = gamma^2.*unm(:,:,m-2+1,n+1);
Fnm = Fnm - temp;
end
if(n>=1 && m>=2)
temp = gamma^2.*unm(:,:,m-2+1,n-1+1);
Fnm = Fnm - temp;
end
if(n>=2 && m>=1)
u_x = dx(unm(:,:,m-1+1,n-2+1),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,m-1+1,n-2+1);
Fnm = Fnm - temp;
end
if(n>=2 && m>=2)
temp = gamma^2.*unm(:,:,m-2+1,n-2+1);
Fnm = Fnm - temp;
end
% Solve elliptic equation
Fnmhat = fft(Fnm);
Jnmhat = fft(Jnm);
b = Fnmhat.';
alphaalpha = 1.0;
betabeta = 0.0;
gammagamma = k2 - alphap.^2;
d_min = 1.0;
n_min = 0.0;
r_min = xi_n_m_hat(:,m+1,n+1);
d_max = -1i*gammap;
n_max = 1.0;
r_max = Jnmhat;
identy = eye(Nz+1);
% Solve Ax = b for x
A = alphaalpha*D2 + betabeta*D + reshape(gammagamma,1,1,Nx).*identy;
A(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
b(end,:) = r_min;
A(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
A(end,end,:) = A(end,end,:) + d_min;
A(1,1,:) = A(1,1,:) + reshape(d_max,1,1,Nx);
b(1,:) = r_max;
% Call the pagemldivide function a lot because of the double for
% loops
utilde = pagemldivide(A,reshape(b,[],1,Nx));
Uhat = utilde(:,:).';
% IFFT to get back to physical space
if((n>0)||(m>0))
unm(:,:,m+1,n+1)=ifft(Uhat);
end
end
end
% Functions to take partial derivatives
function [u_x] = dx(u,p)
u_x = ifft((1i*p).*fft(u));
end
function [u_z] = dz(u,Dz,b)
u_z = ((2.0/b)*Dz*u.').';
end
Here is my attempt to remove the for loops using arrayfun:
% New Code
% This might be doable in one call to arrayfun instead of making multiple functions
temp = arrayfun(@(unm) nGreaterThanOne(unm,Fnm,p,Dz,a), unm,'UniformOutput',false);
temp2 = arrayfun(@(unm) mGreaterThanOne(unm,Fnm,p,gamma,alpha), unm,'UniformOutput',false);
temp3 = arrayfun(@(unm) nGreaterThanOneAndmGreaterThanOne(unm,Fnm,p,alpha,gamma), unm,'UniformOutput',false);
temp4 = arrayfun(@(unm) nGreaterThanTwo(unm,Fnm,p,Dz,a,gamma,alpha), unm,'UniformOutput',false);
temp5 = arrayfun(@(unm) mGreaterThanTwo(unm,Fnm,gamma), unm,'UniformOutput',false);
temp6 = arrayfun(@(unm) nGreaterThanOneAndmGreaterThanTwo(unm,Fnm,gamma), unm,'UniformOutput',false);
temp7 = arrayfun(@(unm) nGreaterThanTwoAndmGreaterThanOne(unm,Fnm,gamma,alpha,p), unm,'UniformOutput',false);
temp8 = arrayfun(@(unm) nGreaterThanTwoAndmGreaterThanTwo(unm,Fnm,gamma), unm,'UniformOutput',false);
% Solve Ax = b for x
A = alphaalpha*D2 + betabeta*D + reshape(gammagamma,1,1,Nx).*identy;
A(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
b(end,:) = r_min;
A(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
A(end,end,:) = A(end,end,:) + d_min;
A(1,1,:) = A(1,1,:) + reshape(d_max,1,1,Nx);
b(1,:) = r_max;
utilde = pagemldivide(A,reshape(b,[],1,Nx));
Uhat = utilde(:,:).';
% I don't know how to remove this double for loop
for n=0:N
for m=0:M
unm(:,:,m+1,n+1)=ifft(Uhat);
end
end
% Functions to take partial derivatives
function [u_x] = dx(u,p)
u_x = ifft((1i*p).*fft(u));
end
function [u_z] = dz(u,Dz,b)
u_z = ((2.0/b)*Dz*u.').';
end
% Functions to pass into arrayfun
% unm is of size (Nx,Nz+1,M+1,N+1) and the for loops go through M and N
function [Fnm] = nGreaterThanOne(unm,Fnm,p,Dz,a)
% True when n > 1 or for unm(:,:,all m, n>1)
u_x = dx(unm(:,:,:,2:end),p);
temp = 2*u_x;
Fnm = Fnm - temp;
% Not sure how to fix this
%u_z = dz(unm(:,:,:,2:end),Dz,a);
u_z = unm(:,:,:,2:end);
temp = 3*u_z;
Fnm = Fnm - dx(temp,p);
end
function [Fnm] = mGreaterThanOne(unm,Fnm,p,gamma,alpha)
% True when n > 1 or for unm(:,:,m>1, all n)
u_x = dx(unm(:,:,2:end,:),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,2:end,:);
Fnm = Fnm - temp;
end
function [Fnm] = nGreaterThanOneAndmGreaterThanOne(unm,Fnm,p,alpha,gamma)
% True when n > 1 and m > 1 or for unm(:,:,m>1,n>1)
u_x = dx(unm(:,:,2:end,2:end),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,2:end,2:end);
Fnm = Fnm - temp;
end
function [Fnm] = nGreaterThanTwo(unm,Fnm,p,Dz,a,gamma,alpha)
% True when n > 2 or for unm(:,:,all m,n>2)
u_x = dx(unm(:,:,:,3:end),p);
temp = 2.*u_x;
Fnm = Fnm - dx(temp,p);
temp = 3.*u_x;
% Fnm = Fnm - dz(temp,Dz,a);
% temp = 4.*u_x;
Fnm = Fnm - temp;
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = gamma^2.*unm(:,:,3:end);
Fnm = Fnm - temp;
end
function [Fnm] = mGreaterThanTwo(unm,Fnm,gamma)
% True when m > 2 or for unm(:,:,m > 2,all n)
temp = gamma^2.*unm(:,:,3:end,:);
Fnm = Fnm - temp;
end
function [Fnm] = nGreaterThanOneAndmGreaterThanTwo(unm,Fnm,gamma)
% True when n > 1 and m > 2 or for unm(:,:,m>2,n>1)
temp = gamma^2.*unm(:,:,3:end,2:end);
Fnm = Fnm - temp;
end
function [Fnm] = nGreaterThanTwoAndmGreaterThanOne(unm,Fnm,gamma,alpha,p)
% True when n > 2 and m > 1 or for unm(:,:,m>1,n>2)
u_x = dx(unm(:,:,2:end,3:end),p);
temp = 2*1i*alpha.*u_x;
Fnm = Fnm - temp;
temp = 2*gamma^2.*unm(:,:,2:end,3:end);
Fnm = Fnm - temp;
end
function [Fnm] = nGreaterThanTwoAndmGreaterThanTwo(unm,Fnm,gamma)
% True when n > 2 and m > 2 or for unm(:,:,m>2,n>2)
temp = gamma^2.*unm(:,:,3:end,3:end);
Fnm = Fnm - temp;
end
Is the recommened approach to remove these for loops through arrayfun (or something similar)? The above code isn't quite correct but is in the spirit of what I would like to do in order to remove the double for loops.
4 Comments
Walter Roberson
on 12 Apr 2023
In order to have a hope that arrayfun would be faster than a for loop, you would need to pass a character vector to arrayfun instead of a function handle, and the character vector would have to be one of several particular options that are no longer documented, and you would have to be using a somewhat recent version of MATLAB.
For example,
arrayfun('isempty', ARRAY)
just might be faster than looping. (Though in this particular case, isempty() of an array element will never be true.. it is more cellfun('isempty', CELL) that could potentially be faster than cellfun(@isempty, CELL) or looping.)
Accepted Answer
Matt J
on 12 Apr 2023
The gpuArray version of arrayfun may help speeds things up, but it has certain restrictions,
0 Comments
More Answers (0)
See Also
Categories
Find more on Data Types in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!