I'm using a Vision Transformer (ViT) in my code. How can I convert the 1-D output of a ViT layer into a 2-D feature map in SSCB format?
Hi Abdulrahman,
I cannot run this code myself because visionTransformer requires the Computer Vision Toolbox, so I adapted the MathWorks example to produce the 24 x 24 x 768 map you asked for. Assuming the default pretrained model (384 x 384 x 3 input image, 16 x 16 patches, 768 channels), the image is split into a 24 x 24 grid of patch tokens. The output of the last layer normalization is therefore a 1-D sequence of 577 tokens (1 class token plus 24 x 24 = 576 patch tokens) by 768 channels, and placing the 576 patch tokens back onto the grid gives the 2-D SSCB output. Here is the updated code, step by step:
% Get Vision Transformer model
net = visionTransformer;
% Create a dummy input image with the size expected by the network's input layer
inputSize = net.Layers(1).InputSize;
X = dlarray(rand(inputSize),'SSCB');
% Obtain output embedding from the last LayerNormalizationLayer
out = forward(net, X, 'Outputs', 'encoder_norm');
% Reshape output patch embedding into a 24 x 24 x 768 x N map
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove the output embedding of the class token (first token)
out = in(2:end,:,:);
% Reshape the patch embeddings onto the square patch grid
WH = sqrt(size(out, 1)); % patches per side (24 for 576 patch tokens)
C = size(out, 2);
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to formatted dlarray
out = dlarray(out, 'SSCB');
end
In this version, the dummy input is an image of the size the network expects; 24 x 24 x 768 is the size of the result, not of the input. reshapePatchEmbedding drops the class token, reshapes the remaining 576 x 768 embedding into 24 x 24 x 768 x N, swaps the first two dimensions so they are height by width, and labels the result as an SSCB dlarray. Please let me know if this helps resolve your issue.
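Because visionTransformer needs the Computer Vision Toolbox, the token-to-grid step can also be checked on its own with plain numeric arrays. The sketch below is illustrative only: gridSize, tokens and the batch size of 2 are hypothetical, each patch token stores its own row in channel 1 and its column in channel 2, and it assumes the patch tokens follow the class token in row-by-row order (column index varying fastest), which is the ordering the "W x H" comment in reshapePatchEmbedding implies. It then applies the same reshape and permute steps without dlarray.

```matlab
% Toolbox-free check of the reshape/permute steps on a synthetic token sequence.
% Assumption: patch tokens come after the class token, ordered row by row.
gridSize = 24;   % patches per side (e.g. 384 / 16)
C = 768;         % embedding channels
N = 2;           % batch size (hypothetical)

numPatches = gridSize^2;
tokens = zeros(1 + numPatches, C, N);   % token x channel x batch
tokens(1, :, :) = -1;                   % class token, discarded below
for k = 1:numPatches
    r = floor((k - 1) / gridSize) + 1;  % patch row
    c = mod(k - 1, gridSize) + 1;       % patch column
    tokens(1 + k, 1, :) = r;            % channel 1 stores the row
    tokens(1 + k, 2, :) = c;            % channel 2 stores the column
end

% Same steps as reshapePatchEmbedding, without the dlarray conversion
out = tokens(2:end, :, :);                      % drop class token: 576 x 768 x N
WH = sqrt(size(out, 1));
out = reshape(out, WH, WH, size(out, 2), []);   % W x H x C x N
out = permute(out, [2, 1, 3, 4]);               % H x W x C x N

disp(size(out))                      % [24 24 768 2]
disp(squeeze(out(3, 5, 1:2, 1))')    % row 3, column 5 -> [3 5]
```

If the patch at row 3, column 5 ends up at out(3,5,:,:), the permute step has put height first and width second, as the SSCB format expects.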