Train AC Agent to Balance Cart-Pole System Using Parallel Computing

This example extends the example Train AC Agent to Balance Cart-Pole System to demonstrate asynchronous parallel training of an Actor-Critic (AC) agent [1] to balance a cart-pole system modeled in MATLAB®.

Actor Parallel Training

When using parallel computing with AC agents, each worker generates experiences from its own copy of the agent and the environment. After every N steps, the worker computes gradients from those experiences and sends them back to the host agent. The host agent then updates its parameters as follows (see the configuration sketch after this list):

  • For asynchronous training, the host agent applies the received gradients and sends the updated parameters back to the worker that provided the gradients. Then, the worker continues to generate experiences from its environment using the updated parameters.

  • For synchronous training, the host agent waits to receive gradients from all of the workers and updates its parameters using these gradients. The host then sends updated parameters to all the workers at the same time. Then, all workers continue to generate experiences using the updated parameters.
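For reference, the mode is selected through the ParallelizationOptions.Mode property of the training options. The following minimal sketch contrasts the two settings; the complete configuration used in this example appears later in Parallel Training Options.

% Minimal sketch: choosing between asynchronous and synchronous parallel training.
parOpts = rlTrainingOptions('UseParallel',true);
parOpts.ParallelizationOptions.Mode = "async";   % workers update the host independently
% parOpts.ParallelizationOptions.Mode = "sync";  % host waits for gradients from all workers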

Create Cart-Pole MATLAB Environment Interface

Create a predefined environment interface for the cart-pole system. For more information on this environment, see Load Predefined Control System Environments.

env = rlPredefinedEnv("CartPole-Discrete");
env.PenaltyForFalling = -10;   % reward (penalty) applied when the pole falls

Obtain the observation and action information from the environment interface.

obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);
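The action specification is a finite set with two elements, the forces -10 and 10. If you want to obtain the number of actions programmatically (it is used later as the output size of the actor network), a small sketch; the variable name numActions is introduced here for illustration:

numActions = numel(actInfo.Elements);   % 2 discrete force values: -10 and 10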

Fix the random generator seed for reproducibility.

rng(0)

Create AC Agent

An AC agent approximates the long-term reward given observations using a critic value function representation. To create the critic, first create a deep neural network with one input (the observation) and one output (the state value). The input size of the critic network is [4 1 1] since the environment provides 4 observations. For more information on creating a deep neural network value function representation, see Create Policy and Value Function Representations.

criticNetwork = [
    imageInputLayer([4 1 1],'Normalization','none','Name','state')  % 4 observations
    fullyConnectedLayer(32,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(1,'Name','CriticFC')];                      % scalar state value

criticOpts = rlRepresentationOptions('LearnRate',1e-2,'GradientThreshold',1);

critic = rlRepresentation(criticNetwork,criticOpts,'Observation',{'state'},obsInfo);

An AC agent decides which action to take given observations using an actor representation. To create the actor, create a deep neural network with one input (the observation) and one output (the action). The output size of the actor network is 2 since the agent can apply 2 force values to the environment, -10 and 10.

actorNetwork = [
    imageInputLayer([4 1 1],'Normalization','none','Name','state')  % 4 observations
    fullyConnectedLayer(32,'Name','ActorStateFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(2,'Name','action')];                        % 2 discrete actions

actorOpts = rlRepresentationOptions('LearnRate',1e-2,'GradientThreshold',1);

actor = rlRepresentation(actorNetwork,actorOpts,'Observation',{'state'},obsInfo,'Action',{'action'},actInfo);

To create the AC agent, first specify the AC agent options using rlACAgentOptions.

agentOpts = rlACAgentOptions(...
    'NumStepsToLookAhead',32, ...
    'EntropyLossWeight',0.01, ...
    'DiscountFactor',0.99);

Then, create the agent using the actor and critic representations and the agent options. For more information, see rlACAgent.

agent = rlACAgent(actor,critic,agentOpts);
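As an optional sanity check (not part of the original example), you can query the untrained agent for an action given a random observation by using the getAction function. The returned action should be one of the two valid forces, -10 or 10.

% Optional check: the agent should return a valid discrete action.
sampleAction = getAction(agent,{rand(obsInfo.Dimension)});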

Parallel Training Options

To train the agent, first specify the training options. For this example, use the following options:

  • Run each training for at most 1000 episodes, with each episode lasting at most 500 time steps.

  • Display the training progress in the Episode Manager dialog box (set the Plots option) and disable the command line display (set the Verbose option).

  • Stop training when the agent receives an average cumulative reward of at least 500 over 10 consecutive episodes. At this point, the agent can balance the pole in the upright position.

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode', 500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',500,...
    'ScoreAveragingWindowLength',10); 

The cart-pole system can be visualized during training or simulation using the plot function.

plot(env)

To train the agent using parallel computing, specify the following training options.

  • Set the UseParallel option to true.

  • Train the agent in parallel asynchronously by setting the ParallelizationOptions.Mode option to "async".

  • Because AC agents require workers to send gradients back to the host, set the ParallelizationOptions.DataToSendFromWorkers option to "gradients".

  • Have each worker compute gradients from its experiences and send them to the host after every 32 steps. For AC agents, ParallelizationOptions.StepsUntilDataIsSent must equal agentOpts.NumStepsToLookAhead, which is 32.

trainOpts.UseParallel = true;
trainOpts.ParallelizationOptions.Mode = "async";
trainOpts.ParallelizationOptions.DataToSendFromWorkers = "gradients";
trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;

For more information, see rlTrainingOptions.
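When UseParallel is true, train uses the current parallel pool if one is open; otherwise, it starts a pool using the default cluster profile. To control the number of workers explicitly (for example, the six workers used to produce the pretrained result below), you can open the pool yourself before training. A minimal sketch, assuming Parallel Computing Toolbox is installed:

% Optional: open a parallel pool with a specific number of workers.
if isempty(gcp('nocreate'))
    parpool(6);
end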

Train Agent

Train the agent using the train function. This is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true. Due to the randomness of asynchronous parallel training, you can expect different results each time you run the training. The pretrained agent was trained using six workers.

doTraining = false;

if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('MATLABCartpoleParAC.mat','agent');
end
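If you train the agent yourself and want to reuse it later without retraining, you can save it to a MAT-file in the same way the pretrained agent is loaded above. The file name in this sketch is an illustrative placeholder.

% Optional: save the newly trained agent for later reuse.
save('MyCartpoleParAC.mat','agent');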

Simulate AC Agent

The cart-pole system can be visualized with plot(env) during simulation.

plot(env)

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

totalReward = sum(experience.Reward)
totalReward = 500

References

[1] Mnih, Volodymyr, et al. "Asynchronous Methods for Deep Reinforcement Learning." In Proceedings of the 33rd International Conference on Machine Learning, 2016.
