
Importing pre-trained recurrent network to reinforcement learning agent

Are pre-trained recurrent networks re-initialized when used in agents for reinforcement learning? If so, how can it be avoided?
I am importing an LSTM network, trained using supervised learning, as the actor for a PPO agent. When simulating without training, the reward is fine. However, if the agent is trained, the reward falls as if no pre-trained network had been used. I would expect the reward to be similar or higher after training, so presumably the network is being re-initialized. Is there a way around it?
% Load actor
actorNetwork = net.Layers;
actorOpts = rlRepresentationOptions('LearnRate',learnRate);
actor = rlStochasticActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'input'},actorOpts);
% Create critic
criticNetwork = [sequenceInputLayer(numObs,"Name","input")
    % (remaining critic layers were cut off in the original post)
    ];
criticOpts = rlRepresentationOptions('LearnRate',learnRate);
critic = rlValueRepresentation(criticNetwork,obsInfo,'Observation',{'input'},criticOpts);
% Create agent
agentOpts = rlPPOAgentOptions('ExperienceHorizon',expHorizon, 'MiniBatchSize',miniBatchSz, 'NumEpoch',nEpoch, 'ClipFactor', 0.1);
agent = rlPPOAgent(actor,critic,agentOpts);
% Train agent
trainOpts = rlTrainingOptions('MaxEpisodes',episodes, 'MaxStepsPerEpisode',episodeSteps, ...
'Verbose',false, 'Plots','training-progress', ...
'StopTrainingCriteria', 'AverageReward'); % remaining options were cut off in the original post
% Run training
trainingStats = train(agent,env,trainOpts);
% Simulate
simOptions = rlSimulationOptions('MaxSteps',2000);
experience = sim(env,agent,simOptions);
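To compare the untrained and trained cases on equal footing, it can help to reduce each simulation to a single episode return. The following is only a minimal sketch, assuming the structure returned by `sim` stores the per-step rewards in a `timeseries` field named `Reward`. `episodeReturn` is a hypothetical helper, and the `timeseries` below is made-up stand-in data rather than real simulation output.

```matlab
% Stand-in for the output of sim(env,agent,simOptions); the values are made up
experience.Reward = timeseries([1; 0.5; 2], [0; 1; 2]);

totalReward = episodeReturn(experience);
fprintf('Episode return: %.2f\n', totalReward);

function total = episodeReturn(experience)
% Sum the per-step rewards stored in the Reward timeseries
total = sum(experience.Reward.Data(:));
end
```

Calling the helper once on a simulation run before `train` and once after gives two numbers that can be compared directly.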

Accepted Answer

Ryan Comeau on 29 May 2020
Transfer learning does not work the same way in RL as it does in DL. In DL, there are no environment physics that need to be understood. Recall that neural networks are really just non-linear curve-fitting tools. Transfer learning in DL works by taking a pre-trained feature-extraction network, which has already learned which shapes are useful (lines, circles, and so on). You then add some of your own images to the mix and obtain some curve-fitting results.
In MATLAB's current RL framework, we are not extracting information from images using a CNN; we are supplying observations as a vector. This means transfer learning will not be useful to you. In addition, the pre-trained network cannot know the physics of the environment that you've made. It will not understand what to do if you halved gravity, for example (because gravity is not observable to the actor). So it has no way of being useful for you.
Hope this helps,
  1 Comment
Javier Maruenda on 1 Jun 2020
Hello Ryan,
Thank you for your answer. I understand that the training is different in DL and RL, but let me clarify the point.
I trained a network using DL and programmatically classified data. The classification is good but may not be the best solution. To find the best solution, I set up an RL environment in which the highest reward corresponds to the best solution.
Using this environment in RL, I get the following results:
  • Training a new agent from scratch: Low reward (around 3-4 points)
  • Using the pre-trained net as the actor and doing RL training: Again low reward (3-4 points)
  • Using the pre-trained net as the actor without performing any training: High reward (35-45 points)
The fact that skipping the RL training and simulating directly results in a high reward suggests that the network has been imported and is working correctly. The lower reward obtained after RL training suggests that reinforcement learning is not very effective at finding an optimal network (there may be large discontinuities, for example). However, given that the net is correctly imported and working, shouldn't reinforcement learning fine-tune a net that is already delivering good results in the environment?
I believe that the network is being re-initialized before RL training, which would explain why the reward is no better than when training a new agent from scratch. So the question is whether the network is in fact being re-initialized, and how that can be avoided. I tried using 'ResetExperienceBufferBeforeTraining', but it is not available for PPO agents.
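One way to test the re-initialization hypothesis directly would be to compare the learnable parameters before and after each step (agent creation, then training). The following is only a sketch under stated assumptions: `maxParamDiff` is a hypothetical helper, the parameter sets are assumed to be cell arrays of plain numeric arrays, and the stand-in values are made up. The toolbox calls named in the comments (`getLearnableParameters`, `getActor`) are assumed to be available in the release in use.

```matlab
% Stand-in parameter sets; in practice these might come from, e.g.,
%   before = getLearnableParameters(actor);
%   after  = getLearnableParameters(getActor(agent));
before  = {single([0.1 0.2; 0.3 0.4]), single([0.5; 0.6])};
same    = before;                         % identical copy: nothing changed
changed = {before{1} + 0.25, before{2}};  % first array perturbed

dSame    = maxParamDiff(before, same);
dChanged = maxParamDiff(before, changed);
fprintf('Max difference (same): %g, (changed): %g\n', dSame, dChanged);

function d = maxParamDiff(pA, pB)
% Largest absolute element-wise difference across two cell arrays of
% equally sized numeric parameter arrays
d = 0;
for k = 1:numel(pA)
    diffK = abs(double(pA{k}(:)) - double(pB{k}(:)));
    d = max([d; diffK]);
end
end
```

A difference of zero right after `rlPPOAgent(actor,critic,agentOpts)` would indicate that the imported weights survive agent creation, in which case any drop in reward would have to come from the training updates themselves.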
Another hypothesis is that the network is not being re-initialized but the learning rate is too high, so that training jumps between different local minima. However, I tried reducing the learning rate to 1e-6 and it did not make any difference either.
Maybe the solution is to change the agent and network to another type that allows importing the net without re-initializing the weights.
Just to clarify, I am doing sequence-to-sequence classification with LSTM nets using the PPO agent.
