Can I "merge" label predictions of multiple images into one single prediction when training a CNN?

Hi everyone,
So I am training a patch-based CNN (with patches cropped out of images), and for now, let's assume that every minibatch (size = 6) fed to the network looks like this:
[patch1_image1.png, patch2_image1.png, patch1_image2.png, patch2_image2.png, patch1_image3.png, patch2_image3.png]
where patch1_image1.png and patch2_image1.png belong to image 1,
patch1_image2.png and patch2_image2.png belong to image 2, etc.
My goal is to average the softmax scores of all the patches belonging to the same image (i.e. average the scores of patch1_image1.png and patch2_image1.png, and so on) and get a new label prediction from this averaged softmax. I was able to do this "manually" after training the network, which gave really promising results. However, I would like to do this during training and get a new "fused" prediction at each iteration. I also don't know whether I can apply this to my validation images. Any thoughts/comments? Should I add a new average pooling layer right before the last output layer, or is there another way to do this?
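As a minimal sketch of the "manual" fusion described above (written in Python for illustration rather than MATLAB; the function names and the example logits are hypothetical), the per-patch softmax scores can be grouped by image, averaged, and the fused label taken as the argmax of the average:

```python
import math

def softmax(logits):
    # Numerically stable softmax over one vector of class scores.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def fuse_patch_predictions(patch_logits, image_ids):
    """Average the softmax scores of all patches that share an image id.

    Returns {image_id: (averaged_scores, fused_class_index)}.
    """
    sums, counts = {}, {}
    for logits, img in zip(patch_logits, image_ids):
        probs = softmax(logits)
        if img not in sums:
            sums[img] = [0.0] * len(probs)
            counts[img] = 0
        sums[img] = [s + p for s, p in zip(sums[img], probs)]
        counts[img] += 1
    fused = {}
    for img, s in sums.items():
        avg = [v / counts[img] for v in s]
        # Fused label = class with the highest averaged score.
        fused[img] = (avg, max(range(len(avg)), key=avg.__getitem__))
    return fused

# Hypothetical 2-class logits for the minibatch layout in the question:
# [patch1_image1, patch2_image1, patch1_image2, patch2_image2, patch1_image3, patch2_image3]
batch_logits = [[2.0, 0.0], [0.0, 1.0], [0.0, 3.0], [0.0, 3.0], [1.0, 1.0], [1.0, 1.0]]
batch_image_ids = [1, 1, 2, 2, 3, 3]
fused = fuse_patch_predictions(batch_logits, batch_image_ids)
```

Because this only needs the network's scores and the patch-to-image mapping, the same post-hoc fusion could also be applied to validation patches without changing the network.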
Thank you very much! :)
Aditya Patil on 17 Nov 2020
Do you want the output of a layer to be the average? That's not possible, as the output of the model at any iteration depends on the input, which is a single image. If instead you want to average the outputs of the model over different forward passes, you should keep that step separate, as it's not part of the training.


Accepted Answer

Madhav Thakker on 23 Nov 2020
+1 to Aditya Patil's comment.
To add more information: average pooling only operates on the output for the current iteration, which will be a single image.
What you want is to average the outputs of different forward passes. One workaround is to use different datastores for different patches and combine them. You might want to look at a custom training loop for the training procedure.
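A rough sketch, in Python for illustration, of what the loss inside such a custom training loop could compute if the fused prediction is used during training: each "stream" holds the logits for one patch position (mirroring one datastore per patch), the softmax scores are averaged per image, and the cross-entropy is taken on that average. The function name and data layout are assumptions, not a MATLAB API; in a real custom training loop the framework's automatic differentiation (in MATLAB, typically `dlgradient` called through `dlfeval`) would supply the gradients of this loss.

```python
import math

def softmax(logits):
    # Numerically stable softmax over one vector of class scores.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def fused_cross_entropy(patch_logits_per_stream, labels):
    """Mean cross-entropy of the patch-averaged softmax against image labels.

    patch_logits_per_stream[k][i] holds the logits of patch k of image i
    (hypothetical layout: one stream per patch position).
    """
    num_streams = len(patch_logits_per_stream)
    total = 0.0
    for i, label in enumerate(labels):
        probs = [softmax(stream[i]) for stream in patch_logits_per_stream]
        num_classes = len(probs[0])
        # Average the softmax scores of all patches of image i.
        avg = [sum(p[c] for p in probs) / num_streams for c in range(num_classes)]
        # Clamp to avoid log(0) for a confidently wrong prediction.
        total += -math.log(max(avg[label], 1e-12))
    return total / len(labels)

# Hypothetical logits: one list per patch position, one entry per image.
patch1_logits = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]
patch2_logits = [[0.0, 1.0], [0.0, 3.0], [1.0, 1.0]]
image_labels = [0, 1, 0]
loss = fused_cross_entropy([patch1_logits, patch2_logits], image_labels)
```

Note that this changes the training objective: the network is optimized for the fused per-image prediction rather than for each patch independently, which is a design choice worth comparing against the post-hoc averaging that already worked.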
Hope this helps.

