Reinforcement Learning Grid World multi-figures
Reinforcement Learning
on 14 Feb 2021
Commented: Reinforcement Learning
on 16 Feb 2021
Hello,
I did my own version of Grid World with my own obstacles (see Code below).
My question is: how can I simulate the trained agent in the environment in multiple figures?
I am using:
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(agent,env)
This gives me a single run in one figure. To get several, I tried:
for i=1:3
figure(i)
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(agent,env)
end
But it didn't work as planned.
Here is my code. For some reason I am getting spikes in the reward plot, although training has already converged. I tried tuning some variables such as LearnRate, Epsilon and DiscountFactor, but this is the best result I can get:
GitterWelt = createGridWorld(7,7);
GitterWelt.CurrentState = '[1,1]';
GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
GitterWelt.TerminalStates = '[6,6]';
updateStateTranstionForObstacles(GitterWelt)
nS = numel(GitterWelt.States);
nA = numel(GitterWelt.Actions);
GitterWelt.R = -1*ones(nS,nS,nA);
GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
env = rlMDPEnv(GitterWelt);
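Obs_Info = getObservationInfo(env); % observation specification, reused below
Act_Info = getActionInfo(env); % action specification, reused below
qTable = rlTable(Obs_Info, Act_Info);
qRep = rlQValueRepresentation(qTable, Obs_Info, Act_Info);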
%% All trivial until here
qRep.Options.LearnRate = 0.2; % Alpha: This was in the example 1, but it doesn't make sense
Ag_Opts = rlQAgentOptions;
Ag_Opts.DiscountFactor = 0.9; % Gamma
Ag_Opts.EpsilonGreedyExploration.Epsilon = 0.02;
agent = rlQAgent(qRep,Ag_Opts);
Train_Opts = rlTrainingOptions;
Train_Opts.MaxEpisodes = 1000;
Train_Opts.MaxStepsPerEpisode = 40;
Train_Opts.StopTrainingCriteria = "AverageReward";
Train_Opts.StopTrainingValue = 10;
Train_Opts.Verbose = 1;
Train_Opts.ScoreAveragingWindowLength = 30; % was assigned to an undefined "trainOpts", so it had no effect
Train_Opts.Plots = "training-progress";
Train_Info = train(agent,env,Train_Opts);
Accepted Answer
Emmanouil Tzorakoleftherakis
on 16 Feb 2021
Hello,
I wouldn't worry about the spikes as long as the average reward has converged. They are most likely episodes in which the agent was still exploring.
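One way to check that the spikes are isolated episodes rather than a loss of convergence is to overlay a moving average on the raw episode rewards. This is only a minimal sketch: it assumes Train_Info is the output of train from the question and that it exposes an EpisodeReward field; the window length of 30 is taken from the question's ScoreAveragingWindowLength, and the variable names win and smoothed are just for illustration.

```matlab
% Sketch: compare raw episode rewards with a trailing moving average.
% Assumes Train_Info is the output of train(agent,env,Train_Opts) above.
win = 30;                                   % averaging window (episodes)
rawReward = Train_Info.EpisodeReward;       % one value per episode
smoothed = movmean(rawReward, [win-1 0]);   % trailing window of up to 30 episodes

figure
plot(rawReward)                             % spiky per-episode reward
hold on
plot(smoothed, 'LineWidth', 1.5)            % smoothed trend
hold off
xlabel('Episode')
ylabel('Reward')
legend('Episode reward', 'Moving average')
```

If the smoothed curve stays flat while only the raw curve dips, the spikes are single episodes and not a change in the learned policy.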
For your plotting question, the plot function for the gridworld environments has been set up with a listener callback so that it can be updated on the fly every time you call step. This means that you can only have one plot per grid world environment.
A quick workaround would be to create a separate environment object for each figure, all built from the same grid world definition, and call plot on each one. For example, wrap the setup in a function:
function env = MyGridWorld
    GitterWelt = createGridWorld(7,7);
    GitterWelt.CurrentState = '[1,1]';
    GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
    GitterWelt.TerminalStates = '[6,6]';
    updateStateTranstionForObstacles(GitterWelt)
    nS = numel(GitterWelt.States);
    nA = numel(GitterWelt.Actions);
    GitterWelt.R = -1*ones(nS,nS,nA);
    GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
    env = rlMDPEnv(GitterWelt);
end
and then create and plot one environment per figure:
env1 = MyGridWorld;
env2 = MyGridWorld;
plot(env1)
plot(env2)
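Putting this together with the loop from the question, simulating the trained agent once per figure could look roughly like the following. This is a sketch, not tested against every release: it assumes MyGridWorld is saved as shown above, that agent is the agent trained in the question, and that each call to plot on a new environment object opens its own viewer. The names numRuns and envs are only illustrative.

```matlab
% Sketch: one environment object (and therefore one figure) per simulation.
% Assumes MyGridWorld.m is on the path and "agent" has already been trained.
numRuns = 3;
envs = cell(1, numRuns);                     % keep handles so the figures stay linked
for i = 1:numRuns
    envs{i} = MyGridWorld;                   % fresh environment with identical layout
    plot(envs{i})                            % viewer attached to this environment only
    envs{i}.Model.Viewer.ShowTrace = true;   % draw the path the agent takes
    envs{i}.Model.Viewer.clearTrace;         % start from an empty trace
    sim(agent, envs{i})                      % run the trained agent in this environment
end
```

Because every environment has the same observation and action specifications as the one used for training, the same agent can be simulated in each of them.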
Hope that helps