DQN weight update using Mean-squared error (MSE)

J AI on 17 Nov 2019
Commented: Bha Pr on 2 Apr 2020
I am designing a DQN from scratch, and I am stuck because I am not sure how to update the weights in the policy network. The structure of my neural networks (both the target network and the policy network) is:
  1. 9 input nodes, since I have 9 states. The state is one-hot encoded, so for example if the input state is 2, the input to the nodes is [0 1 0 0 0 0 0 0 0]'
  2. 2 hidden layers (the number of nodes in each is a free choice, so let's just say 24 and 32, chosen arbitrarily)
  3. 45 output nodes (9 states times 5 actions)
  4. ReLU as the activation function
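The network described above can be sketched as a forward pass in MATLAB. This is a minimal sketch, not the poster's code: the weights are random placeholders, there are no bias terms (none are mentioned), and the output layer is assumed to be linear, since Q-values can be negative and a ReLU there would clip them.

```matlab
% Forward pass for a 9-24-32-45 network with ReLU hidden layers.
% Assumptions: no biases, linear output layer, placeholder weights.
nStates = 9; nH1 = 24; nH2 = 32; nOut = 45;
relu = @(z) max(z, 0);

% Policy weights in a cell array, like wp in the question
wp = {0.1*randn(nH1, nStates), 0.1*randn(nH2, nH1), 0.1*randn(nOut, nH2)};

% One-hot encoding of state 2 -> [0 1 0 0 0 0 0 0 0]'
s = 2;
x = zeros(nStates, 1);
x(s) = 1;

h1 = relu(wp{1}*x);      % 24-by-1
h2 = relu(wp{2}*h1);     % 32-by-1
q  = wp{3}*h2;           % 45-by-1 output
```

Because x is one-hot, wp{1}*x simply selects column s of wp{1}, so only that column receives a gradient for a given sample.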
This is what I am doing currently:
  1. Choose an action, either at random or from the policy, and simulate the environment for that step.
  2. Push this experience into the replay memory.
  3. Pull out a batch of 250 experiences from the memory.
  4. Pass the batch through the target network to find the target values, so I end up with a target Yi for each sample.
  5. Pass the targets into the policy network to calculate the MSE over the whole batch, in a for loop i = 1:250.
  6. With the MSE in hand, the weight update of the policy network (weights θ) collapses because of a dimension mismatch.
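Step 4 above computes the DQN target. A minimal sketch of that computation for a batch, assuming the standard form Yi = ri + discount*max over a' of Qtarget(s(i+1), a'), with Yi = ri when the transition ends the episode. The discount factor, the terminal flags, and the 5-by-250 matrix of target-network action values are placeholders; how the question's 45 outputs map to per-action values is not stated, so the sketch starts from per-action values directly.

```matlab
% DQN targets for a replay batch (all data below is placeholder).
batchSize = 250;
nActions  = 5;
discount  = 0.99;                          % assumed discount factor

rewards    = randn(1, batchSize);          % r_i for each sampled transition
isTerminal = rand(1, batchSize) < 0.1;     % true if s_{i+1} ends the episode
% Action values of each next state s_{i+1} from the target network
% (one column per transition; placeholder numbers here).
qNextTarget = randn(nActions, batchSize);

maxNext = max(qNextTarget, [], 1);         % best next-state value, 1-by-batchSize
Y = rewards + discount * (~isTerminal) .* maxNext;   % terminal rows keep only r_i
```

The variable is named discount rather than gamma to avoid shadowing MATLAB's built-in gamma function.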
Thank you in advance for any help you provide.
Here is my code for the weight update, where Nh = 2 and wp = θ holds the policy-network weights:
% Backpropagate from the output layer to the last hidden layer
error_of_hidden_layer{Nh} = wp{Nh+1}'*mean_square_error;
delta{Nh} = (input_of_hidden_layer{Nh}>0).*error_of_hidden_layer{Nh};   % ReLU derivative
% Propagate through the remaining hidden layers
for i = Nh:-1:2
    error_of_hidden_layer{i-1} = wp{i}'*delta{i};
    delta{i-1} = (input_of_hidden_layer{i-1}>0).*error_of_hidden_layer{i-1};
end
% Weight adjustments, from the output layer back to the first layer
adjustment_of_w{Nh+1} = alpha*mean_square_error*output_of_hidden_layer{Nh}';
for i = Nh:-1:2
    adjustment_of_w{i} = alpha*delta{i}.*output_of_hidden_layer{i-1}';
end
adjustment_of_w{1} = alpha*delta{1}*input_policy_network';
% Apply the adjustments
for i = 1:Nh+1
    wp{i} = wp{i} + adjustment_of_w{i};
end
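To see where the mismatch comes from, the array sizes in the code above can be traced with placeholder arrays. This sketch assumes mean_square_error is a scalar (the batch MSE) and MATLAB R2016b or later, where implicit expansion lets the first steps run silently; W2, W3, z2 and h1 are shortened stand-ins for the names in the question.

```matlab
% Trace array sizes through the first backpropagation steps.
W2  = randn(32, 24);          % stands in for wp{2}
W3  = randn(45, 32);          % stands in for wp{3}
z2  = randn(32, 1);           % input_of_hidden_layer{2}
h1  = rand(24, 1);            % output_of_hidden_layer{1}
mse = 0.7;                    % scalar mean_square_error (placeholder)

err2   = W3' * mse;           % 32-by-45 instead of 32-by-1
delta2 = (z2 > 0) .* err2;    % still 32-by-45 (implicit expansion)
try
    adj2 = delta2 .* h1';     % 32-by-45 .* 1-by-24: incompatible sizes
    failed = false;
catch
    failed = true;
end

% With a 45-by-1 error vector instead, every size lines up
errVec   = randn(45, 1);
delta2ok = (z2 > 0) .* (W3' * errVec);   % 32-by-1
adj2ok   = delta2ok .* h1';              % 32-by-24, same size as W2
```

A scalar error carries no per-node information, so the shapes only work out once the error at the output layer is a 45-by-1 vector.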

Answers (1)

Raunak Gupta on 21 Nov 2019
As I understand it, you are writing the backpropagation for the neural network structure shared by the target and policy networks. Since MSE (mean squared error) is the loss function, the derivative of the loss with respect to each node in the last layer is the difference between the predicted value and the target value (for the per-sample loss 0.5*sum((out - target).^2); the factor 2 or 1/N from other MSE conventions can be absorbed into the learning rate). This follows directly from the backpropagation algorithm. So,
delta_last = out - target;
% Here out is the predicted output of the last layer, 45-by-1
% and target is the target value, also 45-by-1
error_of_hidden_layer{Nh} = wp{Nh+1}'*delta_last;
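The claim that the last-layer delta is out - target can be checked term by term. A short derivation, assuming the per-sample loss carries a factor 1/2 and the output layer is linear (both are assumptions; the question does not say whether ReLU is also applied at the output). Here o is out, t is target, h_2 is output_of_hidden_layer{2} and W_3 is wp{3}; the second line is the batch form written with the DQN target y_i:

```latex
L = \tfrac{1}{2}\sum_{k=1}^{45} (o_k - t_k)^2
\;\Rightarrow\;
\frac{\partial L}{\partial o_k} = o_k - t_k,
\qquad
o = W_3 h_2
\;\Rightarrow\;
\frac{\partial L}{\partial W_3} = (o - t)\, h_2^{\top}

L_{\text{batch}} = \frac{1}{N}\sum_{i=1}^{N} \bigl(y_i - Q(s_i, a_i; \theta)\bigr)^2
\;\Rightarrow\;
\nabla_\theta L_{\text{batch}} = -\frac{2}{N}\sum_{i=1}^{N} \bigl(y_i - Q(s_i, a_i; \theta)\bigr)\,\nabla_\theta Q(s_i, a_i; \theta)
```

The product (o - t) h_2^T is 45-by-32, the size of wp{3}, whereas a scalar loss times h_2^T is only 1-by-32.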
The delta in the next hidden layer is then computed exactly as in your code:
delta{Nh} = (input_of_hidden_layer{Nh}>0).*error_of_hidden_layer{Nh};
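Putting the pieces together, one full update for a single sample might look like the sketch below. It is an illustration under stated assumptions, not a drop-in replacement: no biases, a linear output layer, the loss 0.5*sum((out - target).^2), and placeholder weights and target. Because delta_last = out - target, the gradient is subtracted from the weights. A finite-difference check confirms one gradient entry.

```matlab
% One backpropagation + gradient-descent step for a 9-24-32-45 ReLU network.
% Assumptions: no biases, linear output layer, single sample,
% loss L = 0.5*sum((out - target).^2), placeholder weights and target.
relu  = @(z) max(z, 0);
sizes = [9 24 32 45];
Nh    = 2;                                    % number of hidden layers
alpha = 0.01;                                 % learning rate (assumed)

wp = cell(1, Nh+1);
for i = 1:Nh+1
    wp{i} = 0.1*abs(randn(sizes(i+1), sizes(i)));
end
x = zeros(9, 1); x(2) = 1;                    % one-hot input for state 2
target = randn(45, 1);                        % placeholder 45-by-1 target

% Forward pass, keeping pre-activations z and activations h
z = cell(1, Nh); h = cell(1, Nh);
a = x;
for i = 1:Nh
    z{i} = wp{i}*a;
    h{i} = relu(z{i});
    a = h{i};
end
out = wp{Nh+1}*h{Nh};

% Backward pass
delta_last = out - target;                            % 45-by-1
delta = cell(1, Nh);
delta{Nh} = (z{Nh} > 0).*(wp{Nh+1}'*delta_last);      % 32-by-1
for i = Nh:-1:2
    delta{i-1} = (z{i-1} > 0).*(wp{i}'*delta{i});     % 24-by-1 when Nh = 2
end
grad = cell(1, Nh+1);
grad{Nh+1} = delta_last*h{Nh}';               % 45-by-32
for i = Nh:-1:2
    grad{i} = delta{i}*h{i-1}';               % 32-by-24
end
grad{1} = delta{1}*x';                        % 24-by-9

% Finite-difference check of one gradient entry (closed form, Nh = 2 only)
forward = @(w) w{3}*relu(w{2}*relu(w{1}*x));
lossFn  = @(w) 0.5*sum((forward(w) - target).^2);
epsFD  = 1e-6;
wPlus  = wp; wPlus{1}(5, 2)  = wPlus{1}(5, 2)  + epsFD;
wMinus = wp; wMinus{1}(5, 2) = wMinus{1}(5, 2) - epsFD;
numGrad = (lossFn(wPlus) - lossFn(wMinus))/(2*epsFD);

% Gradient descent: subtract, because delta_last = out - target
lossBefore = lossFn(wp);
for i = 1:Nh+1
    wp{i} = wp{i} - alpha*grad{i};
end
lossAfter = lossFn(wp);
```

The weights are drawn with abs(randn) only so that every ReLU is active in this toy example and the gradient check is not disturbed by a kink; real initialisation would normally use signed values.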
Also, check the loop
for i = Nh:-1:2
With Nh = 2 this loop runs exactly once (i = 2), so it does compute error_of_hidden_layer{1} and delta{1}. Still, I would suggest checking the looping variables of all the loops to avoid off-by-one errors.
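Following the suggestion to check the looping variables, a loop's range can be inspected directly before trusting it. A minimal sketch for Nh = 2; visited is a hypothetical helper that records which error_of_hidden_layer index each pass writes:

```matlab
% How many passes does "for i = Nh:-1:2" make when Nh = 2?
Nh = 2;
iterations = Nh:-1:2;         % the range the loop iterates over
visited = [];
for i = Nh:-1:2
    visited(end+1) = i - 1;   % index written by error_of_hidden_layer{i-1}
end
disp(iterations)
disp(visited)
```

For Nh = 2 the range Nh:-1:2 is the single value 2, so the body runs once and writes index 1.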
You may also want to look at rlDQNAgent (Reinforcement Learning Toolbox), which creates a DQN agent without you having to write everything from scratch.
Raunak Gupta on 22 Nov 2019
The main purpose of my answer was to point out that the delta at the last layer is not the sum of squared errors over the nodes. As stated in the algorithm itself, it is the difference between the target value and the predicted value, i.e.
Yi - Q(si,ai)
This is the negative of out - target above: with Yi - Q(si,ai) as the error the adjustment is added to the weights, and with out - target it is subtracted. This is the core of gradient descent: the weights are updated according to the gradient of the loss, not the value of the loss itself.
Hope this makes it clear.
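The update driven by Yi - Q(si,ai) can be illustrated on the last layer alone. This sketch rests on common-practice assumptions that the thread does not state: the target vector copies the network's own predictions for every output except the one for the action actually taken, and the output layer is linear. The output index a and the target value Y are hypothetical.

```matlab
% Only the output for the action actually taken carries an error signal.
h2  = rand(32, 1);                 % last hidden activation (placeholder)
W3  = 0.1*randn(45, 32);           % last-layer weights (placeholder)
out = W3*h2;                       % predicted values, 45-by-1
a   = 7;                           % hypothetical index of the taken action's output
Y   = 1.5;                         % hypothetical target Yi for this sample
alpha = 0.05;                      % learning rate (assumed)

target    = out;                   % copy the predictions ...
target(a) = Y;                     % ... and replace only the taken entry
delta_last = out - target;         % zero except entry a, which is Q(si,ai) - Yi

W3new  = W3 - alpha*delta_last*h2';   % gradient step on the last layer
outNew = W3new*h2;                    % Q(si,ai) moves toward Yi
```

Since delta_last is zero everywhere except entry a, only row a of the last-layer weights changes, and the size of the step is set by the gradient rather than by the loss value.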
Bha Pr on 2 Apr 2020
I am working with a deep reinforcement learning algorithm on a problem with 4 states (1, 2, 3, 4) and 3 actions (1, 2, 3), where the reward is calculated from real-time measurements such as temperature, pressure, and humidity. For example, if at time t1 the state is 1 and the chosen action is 2, the reward might be 7, but at t2, with the same state 1 and action 2, the reward might be 4; that is, the reward varies with context. How can I implement this problem using a deep reinforcement learning algorithm in MATLAB? Kindly help.
