How does selfAttentionLayer work? Implementing validation with brief code

How does selfAttentionLayer work in detail, step by step? Can you reproduce its working process in a simple way, based on the formulas in the paper, and thereby verify the correctness and consistency of selfAttentionLayer?
Official description:
A self-attention layer computes single-head or multihead self-attention of its input.
The layer:
  1. Computes the queries, keys, and values from the input
  2. Computes the scaled dot-product attention across heads using the queries, keys, and values
  3. Merges the results from the heads
  4. Performs a linear transformation on the merged result
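The four steps can be written compactly in the column-major "CT" layout (channels in rows, time steps in columns) used by the reproduction code in the accepted answer. This is a sketch inferred from that code, not taken from the MathWorks documentation. The symbols are notation introduced here to match the answer's dimension comments: $X \in \mathbb{R}^{C \times T}$ is the input, $W_Q, W_K \in \mathbb{R}^{N_1 \times C}$, $W_V \in \mathbb{R}^{N_2 \times C}$, $W_O \in \mathbb{R}^{N_3 \times N_2}$, $h$ is the number of heads, and each bias vector is added to every column.

```latex
\begin{aligned}
Q &= W_Q X + b_Q,\quad K = W_K X + b_K,\quad V = W_V X + b_V
  && Q, K \in \mathbb{R}^{N_1 \times T},\; V \in \mathbb{R}^{N_2 \times T} \\
A_i &= \operatorname{softmax}_{\mathrm{col}}\!\left(\frac{K_i^{\top} Q_i}{\sqrt{d_q}}\right)
  && A_i \in \mathbb{R}^{T \times T},\; d_q = N_1 / h,\; i = 1,\dots,h \\
H_i &= V_i A_i
  && H_i \in \mathbb{R}^{d_v \times T},\; d_v = N_2 / h \\
H &= \begin{bmatrix} H_1 \\ \vdots \\ H_h \end{bmatrix}
  && H \in \mathbb{R}^{N_2 \times T} \\
Y &= W_O H + b_O
  && Y \in \mathbb{R}^{N_3 \times T}
\end{aligned}
```

Here $Q_i$ and $K_i$ are rows $(i-1)d_q+1$ through $i\,d_q$ of $Q$ and $K$, $V_i$ is rows $(i-1)d_v+1$ through $i\,d_v$ of $V$, and $\operatorname{softmax}_{\mathrm{col}}$ normalizes each column (each query) over the key positions. Transposing $A_i$ and $H_i$ recovers the familiar row-major form $\operatorname{softmax}(QK^{\top}/\sqrt{d_k})\,V$ from "Attention Is All You Need", in which tokens are rows.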

Accepted Answer

cui,xingxing on 11 Jan 2024
Edited: cui,xingxing on 27 Apr 2024
Here is a simple code workflow I wrote to illustrate how each step works, using input with only two dimensions, "CT" (channel × time).
The comment after each variable gives its dimensions.
%% Verify that the selfAttentionLayer computation is consistent with a manual calculation!
XTrain = dlarray(rand(10,20));% CT
numClasses = 4;
numHeads = 6;
queryDims = 48; % N1=48
layers = [inputLayer(size(XTrain),"CT");
    selfAttentionLayer(numHeads,queryDims,NumValueChannels=12,OutputSize=15,Name="sa");
    layerNormalizationLayer;
    fullyConnectedLayer(numClasses);
    softmaxLayer];
net = dlnetwork(layers);
% analyzeNetwork(net)
XTrain = dlarray(XTrain,"CT");
[act1,act2] = predict(net,XTrain,Outputs=["input","sa"]);
act1 = extractdata(act1);% CT
act2 = extractdata(act2);% CT
% layer params
layerSA = net.Layers(2);
QWeights = layerSA.QueryWeights; % N1*C
KWeights = layerSA.KeyWeights;% N1*C
VWeights = layerSA.ValueWeights;% N2*C
outputW = layerSA.OutputWeights;% N3*N2
Qbias = layerSA.QueryBias; % N1*1
Kbias = layerSA.KeyBias;% N1*1
Vbias = layerSA.ValueBias; % N2*1
outputB = layerSA.OutputBias;% N3*1
% step1
q = QWeights*act1+Qbias; % N1*T
k = KWeights*act1+Kbias;% N1*T
v = VWeights*act1+Vbias;% N2*T
% step2,multiple heads
numChannelsQPerHeads = size(q,1)/numHeads;% 1*1
numChannelsVPerHeads = size(v,1)/numHeads;% 1*1
attentionM = cell(1,numHeads);
for i = 1:numHeads
    % row ranges of the query/key and value channels that belong to head i
    idxQRange = numChannelsQPerHeads*(i-1)+1:numChannelsQPerHeads*i;
    idxVRange = numChannelsVPerHeads*(i-1)+1:numChannelsVPerHeads*i;
    qi = q(idxQRange,:);% diQ*T
    ki = k(idxQRange,:);% diQ*T
    vi = v(idxVRange,:);% diV*T
    % scaled dot-product attention
    dk = size(qi,1);% 1*1
    attentionScores = mysoftmax(ki'*qi./sqrt(dk));% T*T, softmax along dim 1 (over keys); note: MATLAB's internal code uses k'*q, not q'*k
    attentionM{i} = vi*attentionScores; % diV*T
end
%step3,merge attentionM
attention = cat(1,attentionM{:}); % N2*T,N2 = diV*numHeads
%step4,output linear projection
act_ = outputW*attention+outputB;% N3*T
act2(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
act_(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
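The comment in step 2 says the layer computes k'*q rather than the paper's q'*k. The two conventions give the same attention weights up to a transpose, provided the softmax runs along the key dimension. Below is a toolbox-free sketch in plain MATLAB with random data that checks this; the sizes d, T, dv and all variable names are hypothetical and chosen so they do not clash with the workspace above.

```matlab
% Hypothetical per-head sizes; any positive integers work
d  = 8;                          % query/key channels per head (diQ in the answer)
T  = 5;                          % number of time steps
dv = 4;                          % value channels per head (diV in the answer)
qh = randn(d, T);                % per-head queries, d-by-T ("CT" layout)
kh = randn(d, T);                % per-head keys,    d-by-T
vh = randn(dv, T);               % per-head values,  dv-by-T

% Softmax along dimension dim, with the max subtracted for numerical stability
softmaxDim = @(S, dim) exp(S - max(S, [], dim)) ./ sum(exp(S - max(S, [], dim)), dim);

% Column-major form (as in the answer): column t holds the weights for query t
A_col = softmaxDim(kh' * qh / sqrt(d), 1);       % T-by-T
% Row-major "paper" form: row t holds the weights for query t
A_row = softmaxDim(qh' * kh / sqrt(d), 2);       % T-by-T

out_col = vh * A_col;                            % dv-by-T
out_row = (A_row * vh.').';                      % (T-by-dv).' = dv-by-T

B = A_row.';
weightDiff = max(abs(A_col(:) - B(:)));          % ~0 up to rounding
outDiff    = max(abs(out_col(:) - out_row(:)));  % ~0 up to rounding
colSums    = sum(A_col, 1);                      % each column sums to 1
```

Either layout works for a manual reproduction; what matters is that the softmax is taken along the key index (dim 1 for k'*q, dim 2 for q'*k).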
I have reproduced its working process in the simplest possible way; I hope it helps others.
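function out = mysoftmax(X,dim)
% softmax of X along dimension dim (default 1, i.e. column-wise)
arguments
    X
    dim = 1;
end
X = X-max(X,[],dim); % subtract the max so that exp(X) cannot overflow to Inf when X is large
X = exp(X);
out = X./sum(X,dim);
end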
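The max-subtraction line in mysoftmax relies on softmax being unchanged when a constant is subtracted from every entry of a column. Subtracting the column maximum therefore changes nothing mathematically but keeps exp() from overflowing. A small plain-MATLAB sketch illustrates this; the input matrices are made-up values for illustration only.

```matlab
% Two column-wise softmax variants as anonymous functions
naiveSoftmax  = @(X) exp(X) ./ sum(exp(X), 1);
stableSoftmax = @(X) exp(X - max(X, [], 1)) ./ sum(exp(X - max(X, [], 1)), 1);

Xsmall = [1 2; 3 4; 5 6];        % moderate scores: both variants agree
Xlarge = [1000; 1001; 1002];     % large scores: exp(1000) overflows to Inf in double

D = naiveSoftmax(Xsmall) - stableSoftmax(Xsmall);
smallDiff = max(abs(D(:)));      % agreement up to rounding
pNaive  = naiveSoftmax(Xlarge);  % Inf./Inf gives NaN in every entry
pStable = stableSoftmax(Xlarge); % finite: equals softmax([-2; -1; 0]), about [0.0900; 0.2447; 0.6652]
```

For the small random inputs in the answer both variants give the same result, so the stabilized version only matters once the attention scores become large.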

Release: R2023b
