How to create a custom MATLAB environment from the class template
5 views (last 30 days)
Show older comments
I created a custom MATLAB environment from the class template for a simple cart-pole model. When I call the validateEnvironment function on it, I get the error: "Method, property or field 'Figure' of class 'MyEnvironment' is not recognized."
classdef MyEnvironment < rl.env.MATLABEnvironment
    %MYENVIRONMENT Custom cart-pole reinforcement learning environment.
    %
    %   The state vector is [x; dx; theta; dtheta] (cart position, cart
    %   velocity, pole angle, pole angular velocity). The action is a
    %   horizontal force taken from the finite set MaxForce*[-1 1].
    %
    %   An episode terminates when |x| > DisplacementThreshold or
    %   |theta| > AngleThreshold.

    %% Properties (set properties' attributes accordingly)
    properties
        % Acceleration due to gravity in m/s^2
        Gravity = 9.8

        % Mass of the cart
        CartMass = 1.0

        % Mass of the pole
        PoleMass = 0.1

        % Half the length of the pole
        HalfPoleLength = 0.5

        % Max force the input can apply
        MaxForce = 10

        % Sample time
        Ts = 0.02

        % Angle at which to fail the episode (radians)
        AngleThreshold = 12 * pi/180

        % Distance at which to fail the episode
        DisplacementThreshold = 2.4

        % Reward each time step the cart-pole is balanced
        RewardForNotFalling = 1

        % Penalty when the cart-pole fails to balance
        PenaltyForFalling = -10
    end

    properties
        % System state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end

    properties (Access = protected)
        % Internal flag to indicate episode termination
        IsDone = false
    end

    properties (Transient, Access = protected)
        % Handle to the visualization figure created by plot(). It is
        % empty until plot() is called, so envUpdatedCallback is a no-op
        % before then. This property must be declared: plot() and
        % envUpdatedCallback() both access this.Figure, and without the
        % declaration MATLAB raises "Method, property or field 'Figure'
        % of class 'MyEnvironment' is not recognized".
        Figure = []
    end

    properties (Transient, Access = private)
        % Not referenced by any method in this class; retained as declared.
        Visualizer = []
    end

    %% Necessary Methods
    methods
        function this = MyEnvironment()
            %MYENVIRONMENT Construct an instance of the environment.
            %   Change the class name and constructor name accordingly.

            % Initialize observation settings
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';

            % Initialize action settings (rescaled by updateActionInfo below)
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';

            % The following line implements built-in functions of RL env
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

            % Scale the action set to MaxForce*[-1 1]
            updateActionInfo(this);
        end

        function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
            %STEP Apply the system dynamics for one sample time.
            %   Action must be an element of this.ActionInfo.Elements.
            %   Returns the next state as Observation, the Reward for this
            %   step, the termination flag IsDone, and an empty
            %   LoggedSignals.
            LoggedSignals = [];

            % Get action (validated against the action set)
            Force = getForce(this,Action);

            % Unpack state vector
            XDot = this.State(2);
            Theta = this.State(3);
            ThetaDot = this.State(4);

            % Cache to avoid recomputation
            CosTheta = cos(Theta);
            SinTheta = sin(Theta);
            SystemMass = this.CartMass + this.PoleMass;
            temp = (Force + this.PoleMass*this.HalfPoleLength * ThetaDot^2 * SinTheta) / SystemMass;

            % Apply motion equations
            ThetaDotDot = (this.Gravity * SinTheta - CosTheta * temp) / ...
                (this.HalfPoleLength * (4.0/3.0 - this.PoleMass * CosTheta * CosTheta / SystemMass));
            XDotDot = temp - this.PoleMass*this.HalfPoleLength * ThetaDotDot * CosTheta / SystemMass;

            % Euler integration
            Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

            % Update system states
            this.State = Observation;

            % Check terminal condition
            X = Observation(1);
            Theta = Observation(3);
            IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
            this.IsDone = IsDone;

            % Get reward
            Reward = getReward(this);

            % (optional) use notifyEnvUpdated to signal that the
            % environment has been updated (e.g. to update visualization)
            notifyEnvUpdated(this);
        end

        function InitialObservation = reset(this)
            %RESET Reset the environment and return the initial observation.
            %   The pole starts at a random angle in [-0.05, 0.05] rad;
            %   all other states start at zero.

            % Theta (+- .05 rad)
            T0 = 2 * 0.05 * rand - 0.05;
            % Thetadot
            Td0 = 0;
            % X
            X0 = 0;
            % Xdot
            Xd0 = 0;

            % Pack in the [x; dx; theta; dtheta] order that step() and the
            % visualization read; the random angle belongs in slot 3.
            InitialObservation = [X0;Xd0;T0;Td0];
            this.State = InitialObservation;

            % (optional) use notifyEnvUpdated to signal that the
            % environment has been updated (e.g. to update visualization)
            notifyEnvUpdated(this);
        end
    end

    %% Optional Methods (set methods' attributes accordingly)
    methods
        % Helper methods to create the environment

        function force = getForce(this,action)
            %GETFORCE Validate the action and return it as the applied force.
            %   The action must be one of this.ActionInfo.Elements, which
            %   updateActionInfo sets to MaxForce*[-1 1].
            if ~ismember(action,this.ActionInfo.Elements)
                error('Action must be %g for going left and %g for going right.',-this.MaxForce,this.MaxForce);
            end
            force = action;
        end

        function updateActionInfo(this)
            %UPDATEACTIONINFO Set the action set to MaxForce*[-1 1].
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end

        function Reward = getReward(this)
            %GETREWARD Return the step reward based on the termination flag.
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end

        function plot(this)
            %PLOT (optional) Create the visualization figure and draw the
            %   current state.

            % Initiate the visualization
            this.Figure = figure('Visible','on','HandleVisibility','off');
            ha = gca(this.Figure);
            ha.XLimMode = 'manual';
            ha.YLimMode = 'manual';
            ha.XLim = [-3 3];
            ha.YLim = [-1 2];
            hold(ha,'on');

            % Update the visualization
            envUpdatedCallback(this)
        end

        % (optional) Properties validation through set methods
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            % Store as a double column vector regardless of input orientation
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end

        function set.HalfPoleLength(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','HalfPoleLength');
            this.HalfPoleLength = val;
            notifyEnvUpdated(this);
        end

        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end

        function set.CartMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','CartMass');
            this.CartMass = val;
        end

        function set.PoleMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','PoleMass');
            this.PoleMass = val;
        end

        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            % Keep the action set consistent with the new force limit
            updateActionInfo(this);
        end

        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end

        function set.AngleThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','AngleThreshold');
            this.AngleThreshold = val;
        end

        function set.DisplacementThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','DisplacementThreshold');
            this.DisplacementThreshold = val;
        end

        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end

        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
    end

    methods (Access = protected)
        function envUpdatedCallback(this)
            %ENVUPDATEDCALLBACK (optional) Update the visualization every
            %   time the environment is updated (notifyEnvUpdated is
            %   called). Does nothing until plot() has created a figure,
            %   or after that figure has been closed.
            if ~isempty(this.Figure) && isvalid(this.Figure)
                % Get the axes of the visualization figure
                ha = gca(this.Figure);

                % Extract the cart position and pole angle
                x = this.State(1);
                theta = this.State(3);

                cartplot = findobj(ha,'Tag','cartplot');
                poleplot = findobj(ha,'Tag','poleplot');
                if isempty(cartplot) || ~isvalid(cartplot) ...
                        || isempty(poleplot) || ~isvalid(poleplot)
                    % Initialize the cart plot
                    cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
                    cartpoly = translate(cartpoly,[x 0]);
                    cartplot = plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980]);
                    cartplot.Tag = 'cartplot';

                    % Initialize the pole plot
                    L = this.HalfPoleLength*2;
                    polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
                    polepoly = translate(polepoly,[x,0]);
                    polepoly = rotate(polepoly,rad2deg(theta),[x,0]);
                    poleplot = plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410]);
                    poleplot.Tag = 'poleplot';
                else
                    % Reuse the shapes already on the axes
                    cartpoly = cartplot.Shape;
                    polepoly = poleplot.Shape;
                end

                % Compute the new cart and pole position as increments
                % relative to the currently drawn shapes
                [cartposx,~] = centroid(cartpoly);
                [poleposx,poleposy] = centroid(polepoly);
                dx = x - cartposx;
                dtheta = theta - atan2(cartposx-poleposx,poleposy-0.25/2);
                cartpoly = translate(cartpoly,[dx,0]);
                polepoly = translate(polepoly,[dx,0]);
                polepoly = rotate(polepoly,rad2deg(dtheta),[x,0.25/2]);

                % Update the cart and pole positions on the plot
                cartplot.Shape = cartpoly;
                poleplot.Shape = polepoly;

                % Refresh rendering in the figure window
                drawnow();
            end
        end
    end
end
0 Comments
Answers (0)
See Also
Categories
Find more on Environments in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!