Agent diverges after converging using TD3
Hello Community,
I am working on a problem in wireless communication using RL. Some parameters, such as the wireless channel, change across episodes, and I'm struggling to choose suitable hyperparameters for this kind of problem.
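(For context: ObservationInfo and ActionsInfo below come from my custom environment. A minimal stand-in with placeholder sizes, just so the snippet is self-contained, would be:)
% ObservationInfo / ActionsInfo are normally taken from the environment:
%   ObservationInfo = getObservationInfo(env);
%   ActionsInfo     = getActionInfo(env);
% Stand-in specs with placeholder sizes (1-by-8 observation, 1-by-4 action):
ObservationInfo = rlNumericSpec([1 8]);
ActionsInfo     = rlNumericSpec([1 4],'LowerLimit',-1,'UpperLimit',1);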
My actor-critic structure:
actorNetwork = [
    %featureInputLayer(ObservationInfo.Dimension(2),'Normalization','zscore','Name','observation','Mean',0,'StandardDeviation',1)
    featureInputLayer(ObservationInfo.Dimension(2),'Normalization','None','Name','observation')
    fullyConnectedLayer(64,'Name','fc1')
    layerNormalizationLayer('Name','BN1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(64,'Name','fc2')
    layerNormalizationLayer('Name','BN2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(64,'Name','fc3')
    layerNormalizationLayer('Name','BN3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(ActionsInfo.Dimension(2),'Name','fc4')
    tanhLayer("Name","tanh")];
actorNetwork = layerGraph(actorNetwork);
actorOptions = rlOptimizerOptions('LearnRate',1e-3,'GradientThreshold',inf,'L2RegularizationFactor',1e-4);
actorOptions.OptimizerParameters.GradientDecayFactor = 1e-7 ;
actor = rlDeterministicActorRepresentation(actorNetwork,ObservationInfo,ActionsInfo,...
'Observation',{'observation'},'Action',{'tanh'},actorOptions);
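(Side note, not part of my original setup: since the critics below use the newer approximator-function API, the actor could be created the same way. A sketch assuming R2022a or later, where the optimizer options are then passed through the agent options instead of the representation:)
% Newer-API equivalent of the actor representation above (sketch only)
actor = rlContinuousDeterministicActor(actorNetwork,ObservationInfo,ActionsInfo, ...
    'ObservationInputNames',{'observation'});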
%figure(1)
%plot(actorNetwork)
%% Critic Network
%
% % create a network to be used as underlying critic approximator
% statePath = [featureInputLayer(ObservationInfo.Dimension(2), 'Normalization', 'zscore', 'Name', 'state','Mean',0,'StandardDeviation',1)
% fullyConnectedLayer(32,'Name','fc1')];
% actionPath = [featureInputLayer(ActionsInfo.Dimension(2), 'Normalization', 'zscore', 'Name', 'action','Mean',0,'StandardDeviation',1)
% fullyConnectedLayer(32,'Name','fc2')];
statePath = [featureInputLayer(ObservationInfo.Dimension(2), 'Normalization', 'None', 'Name', 'state')
    fullyConnectedLayer(64,'Name','fc1')];
actionPath = [featureInputLayer(ActionsInfo.Dimension(2), 'Normalization', 'None', 'Name', 'action')
    fullyConnectedLayer(64, 'Name','fc2')];
commonPath = [concatenationLayer(1,2,'Name','concat')
    fullyConnectedLayer(64, 'Name', 'CriticStateFC1')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(64, 'Name', 'CriticStateFC2')
    layerNormalizationLayer('Name','CriticBN1')
    reluLayer('Name', 'CriticRelu2')
    fullyConnectedLayer(64, 'Name', 'CriticStateFC3')
    layerNormalizationLayer('Name','CriticBN2')
    reluLayer('Name','CriticRelu3')
    fullyConnectedLayer(1,'Name','StateValue')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);
criticNetwork = connectLayers(criticNetwork,'fc1','concat/in1');
criticNetwork = connectLayers(criticNetwork,'fc2','concat/in2');
criticOptions = rlOptimizerOptions('LearnRate',2e-03,'GradientThreshold',inf);
criticOptions.OptimizerParameters.GradientDecayFactor = 1e-7 ;
%critic1 = rlQValueRepresentation(criticNetwork,ObservationInfo,ActionsInfo,...
% 'Observation',{'state'},'Action',{'action'},criticOptions);
%critic2 = rlQValueRepresentation(criticNetwork,ObservationInfo,ActionsInfo,...
% 'Observation',{'state'},'Action',{'action'},criticOptions);
% Approximator function recommended for DDPG and TD3
critic1 = rlQValueFunction(criticNetwork,ObservationInfo,ActionsInfo,...
    'ObservationInputNames',{'state'},'ActionInputNames',{'action'});
critic2 = rlQValueFunction(criticNetwork,ObservationInfo,ActionsInfo,...
    'ObservationInputNames',{'state'},'ActionInputNames',{'action'});
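(Another aside: because both critics are built from the same layerGraph, they may start from identical weights, which weakens TD3's twin-critic mechanism. A sketch of one way to give them independent random initializations, assuming a release where dlnetwork and initialize are available:)
% Re-initialize the shared critic network separately for each critic (sketch)
criticDLNet = dlnetwork(criticNetwork,'Initialize',false);
critic1 = rlQValueFunction(initialize(criticDLNet),ObservationInfo,ActionsInfo, ...
    'ObservationInputNames',{'state'},'ActionInputNames',{'action'});
critic2 = rlQValueFunction(initialize(criticDLNet),ObservationInfo,ActionsInfo, ...
    'ObservationInputNames',{'state'},'ActionInputNames',{'action'});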
%figure(2)
%plot(criticNetwork)
% For PPO
%criticNetwork = dlnetwork(criticNetwork);
%critic1 = rlValueFunction(criticNetwork,ObservationInfo,...
% 'ObservationInputNames',{'state'});
%critic2 = rlValueFunction(criticNetwork,ObservationInfo,...
% 'ObservationInputNames',{'state'});
%%
PPOagentOpts = rlPPOAgentOptions(...
    'ExperienceHorizon',600,...
    'ClipFactor',0.02,...
    'EntropyLossWeight',0.01,...
    'ActorOptimizerOptions',actorOptions,...
    'CriticOptimizerOptions',criticOptions,...
    'NumEpoch',3,...
    'AdvantageEstimateMethod','gae',...
    'GAEFactor',0.95,...
    'DiscountFactor',0.997);
DDPGagentOptions = rlDDPGAgentOptions("DiscountFactor",0.99,...
    "ActorOptimizerOptions",actorOptions,...
    "TargetUpdateFrequency",5000,...
    "CriticOptimizerOptions",criticOptions,...
    "MiniBatchSize",16);
TD3agentOptions = rlTD3AgentOptions("DiscountFactor",0.999, ...
    "ExplorationModel",rl.option.OrnsteinUhlenbeckActionNoise, ...
    "PolicyUpdateFrequency",100, ...
    "ExperienceBufferLength",2e6, ...
    "MiniBatchSize",16, ...
    "NumStepsToLookAhead",1, ...
    "ActorOptimizerOptions",actorOptions, ...
    "CriticOptimizerOptions",criticOptions, ...
    "TargetSmoothFactor",0.0005, ...
    "TargetUpdateFrequency",10000, ...
    "SampleTime",1);
SACagentOptions = rlSACAgentOptions("TargetUpdateFrequency",5000,...
    "MiniBatchSize",64,...
    "ExperienceBufferLength",2e6,...
    "PolicyUpdateFrequency",100,...
    "CriticUpdateFrequency",50,...
    "ActorOptimizerOptions",actorOptions,...
    "CriticOptimizerOptions",criticOptions,...
    "DiscountFactor",0.99);
SACagentOptions.EntropyWeightOptions.EntropyWeight = 1000 ;
TD3agentOptions.ExplorationModel.StandardDeviation = repmat(0.3,1,ActionsInfo.Dimension(2));
TD3agentOptions.ExplorationModel.StandardDeviationDecayRate = 1e-6;
% agent = rlDDPGAgent(actor,critic1,DDPGagentOptions) ;
%agent = rlTRPOAgent(actor,critic1) ;
agent = rlTD3Agent(actor,[critic1, critic2],TD3agentOptions) ;
% agent = rlSACAgent(actor,[critic1, critic2],SACagentOptions) ;
%agent = rlPPOAgent(actor,critic1,PPOagentOpts) ;
opt = rlTrainingOptions(...
    'MaxEpisodes',5000,...
    'MaxStepsPerEpisode',1000,...
    'StopTrainingCriteria',"AverageReward",...
    'StopTrainingValue',480,...
    'Verbose',1,...
    'ScoreAveragingWindowLength',40,...
    'UseParallel',false);
opt.ParallelizationOptions.Mode = "sync";
opt.ParallelizationOptions.StepsUntilDataIsSent = 20000;
% opt.ParallelizationOptions.DataToSendFromWorkers = "Gradients";
opt.ParallelizationOptions;
trainResults = train(agent,env,opt);
% delete(gcp('nocreate'))
The agent acts with very strange behavior: it converges and then diverges again.
What's wrong?
Answers (1)
Emmanouil Tzorakoleftherakis
on 25 Jan 2023
Please see the answer here. This behavior can happen if the optimization moves off in a different direction to explore. I would suggest stopping training around the point where the average episode reward peaks and checking whether that agent behaves as expected.
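For reference, one way to do that with rlTrainingOptions (a rough sketch; the threshold and file name below are only placeholders) is to save a copy of the agent around its peak and then simulate the saved copy:
% Save the agent whenever the average reward crosses a threshold near the peak,
% so the best policy is kept even if training later diverges.
opt.SaveAgentCriteria  = "AverageReward";
opt.SaveAgentValue     = 400;            % placeholder threshold
opt.SaveAgentDirectory = "savedAgents";
trainResults = train(agent,env,opt);

% Load one of the saved agents and check its behavior in simulation
s = load(fullfile("savedAgents","Agent1234.mat"));  % example file name; variable is typically saved_agent
simOpts = rlSimulationOptions('MaxSteps',1000);
experience = sim(env,s.saved_agent,simOpts);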