How can I define a custom loss function using trainnet?
Hello,
I am trying to define a custom loss function using trainnet. The documentation says:
If the trainnet function does not provide the loss function that you need for your task, then you can specify a custom loss function to the trainnet as a function handle. The function must have the syntax loss = f(Y,T), where Y and T are the predictions and targets, respectively.
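To make the documented syntax concrete, here is a minimal sketch of a handle of the form loss = f(Y,T). It assumes that Y (predictions) and T (targets) arrive as dlarray objects of the same size, for example "SSCB" image batches. The name sqErrLoss is hypothetical, and this plain squared-error loss is only an illustration. It is not claimed to match trainnet's built-in "mse" exactly.

```matlab
% Sketch of a custom loss with the loss = f(Y,T) syntax.
% Y holds the network predictions and T the matching targets; both are
% assumed here to be dlarray objects of the same size (e.g. "SSCB" images).
% The handle must reduce everything to a single scalar value.
sqErrLoss = @(Y,T) mean((Y - T).^2,"all");   % hypothetical name

% Quick check with made-up data: 4-by-4 single-channel images, 3 observations
Y = dlarray(ones(4,4,1,3),"SSCB");
T = dlarray(zeros(4,4,1,3),"SSCB");
L = sqErrLoss(Y,T);   % scalar dlarray equal to 1
```

The key point is the "all" reduction. Without it, the handle would return an array, and trainnet cannot differentiate a non-scalar value.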
However, I am not sure how the predictions and targets are defined here. I am currently using trainnet as follows:
trainedNet = trainnet(dsTrain,layers,"mse",options);
dsTrain is a datastore containing the input and target images for the regression problem. However, I would like to change the loss to a custom function involving ssim, something similar to the following (although I know this isn't quite right):
trainedNet = trainnet(dsTrain,layers,@(Y,targets) 1-ssim(Y,targets),options);
I get the following error message:
Error using trainnet
Value to differentiate is non-scalar. It must be a traced real dlarray scalar.
Thanks!
Answers (1)
Matt J on 29 Mar 2024 (edited by Matt J on 29 Mar 2024)
If you have multichannel output, the loss function will give you one SSIM value per channel, e.g.,
loss = @(Y,targets) 1-ssim(Y,targets);
[Y,T]=deal(dlarray(rand(5,4,8),'SSC'));
L=loss(Y,T);
whos L
You need to decide how you want this reduced to a single value, since trainnet requires the loss to be a scalar.
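One possible reduction is to average the per-channel SSIM values, sketched below. This assumes, as in the example above, that ssim accepts formatted dlarray inputs in your release. It also assumes the images are scaled to [0,1], which matches ssim's default dynamic range for floating-point data. The name ssimLoss is hypothetical, and a simple mean is only one choice. A weighted sum over channels would also give a scalar.

```matlab
% Sketch: SSIM-based loss reduced to a scalar for trainnet.
% ssim returns one value per channel (and, for batched data, possibly per
% observation), so mean(...,"all") averages everything into one number.
ssimLoss = @(Y,T) 1 - mean(ssim(Y,T),"all");   % hypothetical name

% Check on random data formatted like the example above
[Y,T] = deal(dlarray(rand(5,4,8),"SSC"));
L = ssimLoss(Y,T);
whos L    % now a 1-by-1 dlarray

% The handle could then be passed to trainnet in place of "mse":
% trainedNet = trainnet(dsTrain,layers,ssimLoss,options);
```

Because Y and T are identical in this check, the SSIM values are 1 and the loss is approximately 0.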