Class sdm::QLearning
template <class TInput>
Class List > sdm > QLearning
Q-Learning and its extensions (DQN, etc).
#include <q_learning.hpp>
Inherits the following classes: sdm::Algorithm
Public Functions
Type | Name |
---|---|
QLearning (std::shared_ptr< GymInterface > & env, std::shared_ptr< ExperienceMemoryInterface > experience_memory, std::shared_ptr< QValueFunction< TInput >> q_value_table, std::shared_ptr< QValueFunction< TInput >> q_value_table_target, std::shared_ptr< QValueBackupInterface > backup, std::shared_ptr< EpsGreedy > exploration, number horizon, double discount=0.9, double lr=0.001, double smooth=0.99, unsigned long num_episodes=10000, std::string name="qlearning") | |
void | do_episode () Execute an episode. |
virtual void | do_initialize () Initialize the algorithm. Initialize the Q-Value function. |
virtual void | do_save () Save the value function. |
virtual void | do_solve () Learning procedure. Will attempt to solve the problem. |
void | do_step () Execute a step. |
virtual void | do_test () Test the result of a problem. |
virtual double | getResult () |
double | getResultOpti () |
virtual int | getTrial () |
void | initLogger () |
std::shared_ptr< Action > | select_action (const std::shared_ptr< Observation > & observation, number t) |
std::shared_ptr< Action > | select_greedy_action (const std::shared_ptr< Observation > & observation, number t) |
void | update_model () Update the q-value functions based on the memory/experience. |
void | update_target () Update the target model. |
Public Functions inherited from sdm::Algorithm
See sdm::Algorithm
Type | Name |
---|---|
virtual void | do_initialize () = 0 Initialize the algorithm. |
virtual void | do_save () = 0 Save the policy in a file. |
virtual void | do_solve () = 0 Solve the problem. |
virtual void | do_test () = 0 Test the result of the algorithm. |
virtual double | getResult () = 0 |
virtual int | getTrial () = 0 |
virtual | ~Algorithm () |
Protected Attributes
Type | Name |
---|---|
double | E_R |
double | R |
std::shared_ptr< QValueBackupInterface > | backup_ |
double | discount_ |
std::shared_ptr< GymInterface > | env_ The problem to be solved. |
unsigned long | episode |
std::shared_ptr< ExperienceMemoryInterface > | experience_memory_ The experience memory. |
std::shared_ptr< EpsGreedy > | exploration_process The exploration process. |
unsigned long | global_step |
number | horizon_ Some hyperparameters for the algorithm. |
std::shared_ptr< MultiLogger > | logger_ The logger. |
double | lr_ |
std::string | name_ = "qlearning" |
unsigned long | num_episodes_ |
std::shared_ptr< QValueFunction< TInput > > | q_value_table_ Q-value function. |
std::shared_ptr< QValueFunction< TInput > > | q_value_table_target_ Q-value target function. |
std::vector< double > | rewards_ |
double | smooth_ |
number | step |
Public Functions Documentation
function QLearning
sdm::QLearning::QLearning (
std::shared_ptr< GymInterface > & env,
std::shared_ptr< ExperienceMemoryInterface > experience_memory,
std::shared_ptr< QValueFunction < TInput >> q_value_table,
std::shared_ptr< QValueFunction < TInput >> q_value_table_target,
std::shared_ptr< QValueBackupInterface > backup,
std::shared_ptr< EpsGreedy > exploration,
number horizon,
double discount=0.9,
double lr=0.001,
double smooth=0.99,
unsigned long num_episodes=10000,
std::string name="qlearning"
)
function do_episode
void sdm::QLearning::do_episode ()
function do_initialize
virtual void sdm::QLearning::do_initialize ()
Implements sdm::Algorithm::do_initialize
function do_save
virtual void sdm::QLearning::do_save ()
Implements sdm::Algorithm::do_save
function do_solve
virtual void sdm::QLearning::do_solve ()
Implements sdm::Algorithm::do_solve
function do_step
void sdm::QLearning::do_step ()
function do_test
virtual void sdm::QLearning::do_test ()
Implements sdm::Algorithm::do_test
function getResult
inline virtual double sdm::QLearning::getResult ()
Implements sdm::Algorithm::getResult
function getResultOpti
inline double sdm::QLearning::getResultOpti ()
function getTrial
inline virtual int sdm::QLearning::getTrial ()
Implements sdm::Algorithm::getTrial
function initLogger
void sdm::QLearning::initLogger ()
function select_action
std::shared_ptr< Action > sdm::QLearning::select_action (
const std::shared_ptr< Observation > & observation,
number t
)
function select_greedy_action
std::shared_ptr< Action > sdm::QLearning::select_greedy_action (
const std::shared_ptr< Observation > & observation,
number t
)
function update_model
void sdm::QLearning::update_model ()
function update_target
void sdm::QLearning::update_target ()
Protected Attributes Documentation
variable E_R
double sdm::QLearning< TInput >::E_R;
variable R
double sdm::QLearning< TInput >::R;
variable backup_
std::shared_ptr<QValueBackupInterface> sdm::QLearning< TInput >::backup_;
variable discount_
double sdm::QLearning< TInput >::discount_;
variable env_
std::shared_ptr<GymInterface> sdm::QLearning< TInput >::env_;
variable episode
unsigned long sdm::QLearning< TInput >::episode;
variable experience_memory_
std::shared_ptr<ExperienceMemoryInterface> sdm::QLearning< TInput >::experience_memory_;
variable exploration_process
std::shared_ptr<EpsGreedy> sdm::QLearning< TInput >::exploration_process;
variable global_step
unsigned long sdm::QLearning< TInput >::global_step;
variable horizon_
number sdm::QLearning< TInput >::horizon_;
variable logger_
std::shared_ptr<MultiLogger> sdm::QLearning< TInput >::logger_;
variable lr_
double sdm::QLearning< TInput >::lr_;
variable name_
std::string sdm::QLearning< TInput >::name_;
variable num_episodes_
unsigned long sdm::QLearning< TInput >::num_episodes_;
variable q_value_table_
std::shared_ptr<QValueFunction<TInput> > sdm::QLearning< TInput >::q_value_table_;
variable q_value_table_target_
std::shared_ptr<QValueFunction<TInput> > sdm::QLearning< TInput >::q_value_table_target_;
variable rewards_
std::vector<double> sdm::QLearning< TInput >::rewards_;
variable smooth_
double sdm::QLearning< TInput >::smooth_;
variable step
number sdm::QLearning< TInput >::step;
The documentation for this class was generated from the following file: src/sdm/algorithms/q_learning.hpp