diff --git a/CMakeLists.txt b/CMakeLists.txt index aa7f544..768fe78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,5 +27,5 @@ MESSAGE(STATUS "Generated global header file") set(INSTALL_DIRECTORY /usr/local/include) install(FILES ${PROJECT_BINARY_DIR}/rl DESTINATION ${INSTALL_DIRECTORY}) -install(DIRECTORY ${PROJECT_BINARY_DIR}/${SUB_HEADERS_PREFIX} +install(DIRECTORY ${PROJECT_BINARY_DIR}/${SUB_HEADERS_PREFIX}/ DESTINATION ${INSTALL_DIRECTORY}/${SUB_HEADERS_PREFIX}) diff --git a/include/agent/ActionSet.h b/include/agent/ActionSet.h index ac30788..8a5f127 100644 --- a/include/agent/ActionSet.h +++ b/include/agent/ActionSet.h @@ -6,6 +6,8 @@ #include +#include "../declares.h" + using namespace std; namespace rl { @@ -27,46 +29,45 @@ class ActionSet { /** * @param actionSet Initial set of action. */ - ActionSet(const set &actionSet); + ActionSet(const spActionSet &actionSet); /** * @return set of actions. */ - const set &getActionSet() const; + const spActionSet &getActionSet() const; /** * @param data A to be added. */ - void addAction(const A &data); + void addAction(const rl::spAction &data); /** * @param dataSet replace the action set with a new one. */ - void setActionSet(set dataSet); + void setActionSet(const spActionSet& dataSet); protected: - set _actionData; - + spActionSet _actionData; }; template ActionSet::ActionSet() {} template -ActionSet::ActionSet(const set &actionSet) : _actionData(actionSet) {} +ActionSet::ActionSet(const spActionSet &actionSet) : _actionData(actionSet) {} template -void ActionSet::addAction(const A &data) { +void ActionSet::addAction(const spAction &data) { _actionData.insert(data); } template -const set &ActionSet::getActionSet() const { +const spActionSet &ActionSet::getActionSet() const { return _actionData; } template -void ActionSet::setActionSet(set dataSet) { +void ActionSet::setActionSet(const spActionSet& dataSet) { _actionData = dataSet; } diff --git a/include/agent/Actuator.h b/include/agent/Actuator.h index f77dd36..ade1db6 100644 --- a/include/agent/Actuator.h +++ b/include/agent/Actuator.h @@ -38,17 +38,17 @@ class Actuator : public ActionSet { * Constructor for when actions (or some actions) are known. * @param actionSet Set of actions. */ - Actuator(const set& actionSet); + Actuator(const spActionSet& actionSet); }; -typedef Actuator ActuatorSL; +typedef Actuator ActuatorSL; template -rl::agent::Actuator::Actuator() { +Actuator::Actuator() { } template -rl::agent::Actuator::Actuator(const set& actionSet) : ActionSet(actionSet) { +Actuator::Actuator(const spActionSet& actionSet) : ActionSet(actionSet) { } } // namespace agent diff --git a/include/agent/Agent.h b/include/agent/Agent.h index 488d759..7357a55 100644 --- a/include/agent/Agent.h +++ b/include/agent/Agent.h @@ -53,7 +53,10 @@ class AgentSupervised { * @param reward * @param nextState */ - virtual void train(const S& state, const A& action, FLOAT reward, const S& nextState) { + virtual void train(const spState& state, + const spAction& action, + FLOAT reward, + const spState& nextState) { this->_learningAlgorithm.update(StateAction(state, action), nextState, reward, @@ -92,7 +95,10 @@ class Agent { * @param reward * @param nextState */ - virtual void train(const S& state, const A& action, FLOAT reward, const S& nextState); + virtual void train(const spState& state, + const spAction& action, + FLOAT reward, + const spState& nextState); /** * Prepare agent prior to start execution. 
@@ -131,12 +137,12 @@ class Agent { /** * @param action Agent applies action to the environment. */ - virtual void applyAction(const A& action); + virtual void applyAction(const spAction& action); /** * @return CurrentState of the agent. */ - virtual S getLastObservedState() const; + virtual spState getLastObservedState() const; /** * Calls the Environment::reset @@ -151,8 +157,8 @@ class Agent { Environment& _environment; // !< Aggregate environment obj. - S _currentState; //!< Keeps track of the current state. - A _currentAction; //!< Keeps track of the current action. + spState _currentState; //!< Keeps track of the current state. + spAction _currentAction; //!< Keeps track of the current action. FLOAT _accumulativeReward; //!< Keeps track of accumulation of reward during //!< the span of the episode.Specifically, after @@ -173,7 +179,7 @@ template using AgentSL = Agent, vector>; template -rl::agent::Agent::Agent(Environment& environment, +Agent::Agent(Environment& environment, algorithm::LearningAlgorithm& learningAlgorithm) : _environment(environment), _learningAlgorithm(learningAlgorithm), @@ -184,12 +190,16 @@ rl::agent::Agent::Agent(Environment& environment, } template -S rl::agent::Agent::getLastObservedState() const { +spState Agent::getLastObservedState() const { return _environment.getSensor().getLastObservedState(); } template -void rl::agent::Agent::train(const S& state, const A& action, FLOAT reward, const S& nextState) { +void Agent::train( + const spState& state, + const spAction& action, + FLOAT reward, + const spState& nextState) { this->_learningAlgorithm.update( StateAction(state, action), nextState, @@ -198,7 +208,7 @@ void rl::agent::Agent::train(const S& state, const A& action, FLOAT reward } template -void rl::agent::Agent::preExecute() { +void Agent::preExecute() { _currentState = std::move(getLastObservedState()); _currentAction = std::move(_learningAlgorithm.getAction( _currentState, _environment.getActuator().getActionSet())); @@ -207,10 +217,10 @@ void rl::agent::Agent::preExecute() { } template -void rl::agent::Agent::execute() { +void Agent::execute() { // todo: Acquire last state and reward here. this->applyAction(_currentAction); - S nextState = std::move(getLastObservedState()); + spState nextState = std::move(getLastObservedState()); FLOAT reward = this->_environment.getSensor().getLastObservedReward(); // Accumulate reward. 
@@ -229,7 +239,7 @@ void rl::agent::Agent::execute() { } template -size_t rl::agent::Agent::executeEpisode(UINT maxIter){ +size_t Agent::executeEpisode(UINT maxIter){ preExecute(); UINT i = 0; for(; i < maxIter && episodeDone() == false; i++) { @@ -240,27 +250,27 @@ size_t rl::agent::Agent::executeEpisode(UINT maxIter){ } template -bool rl::agent::Agent::episodeDone() { +bool Agent::episodeDone() { return _environment.getSensor().isTerminalState(_currentState); } template -rl::FLOAT rl::agent::Agent::postExecute() { +rl::FLOAT Agent::postExecute() { return _accumulativeReward; } template -inline rl::FLOAT rl::agent::Agent::getAccumulativeReward() const { +inline rl::FLOAT Agent::getAccumulativeReward() const { return _accumulativeReward; } template -void rl::agent::Agent::applyAction(const A& action) { +void Agent::applyAction(const spAction& action) { this->_environment.applyAction(action); } template -void rl::agent::Agent::reset() { +void Agent::reset() { this->_environment.reset(); }; diff --git a/include/agent/Environment.h b/include/agent/Environment.h index 0ad72df..2821735 100644 --- a/include/agent/Environment.h +++ b/include/agent/Environment.h @@ -34,13 +34,13 @@ class Environment { * @param stateAction Given this state-action, gives next state and reward. * @return Next state and reward. */ - virtual std::pair getNextStateAndReward(const SA &stateAction) = 0; + virtual spStateAndReward getNextStateAndReward(const SA &stateAction) = 0; /** * @see Actuator Documentation for example. * @param action to be applied to environment. */ - virtual StateAndReward applyAction(const A &action); + virtual spStateAndReward applyAction(const spAction &action); Actuator &getActuator(); const Actuator &getActuator() const; @@ -65,10 +65,10 @@ inline Environment::Environment(Actuator &actuator, Sensor &sensor) } template -StateAndReward Environment::applyAction(const A &action) { - auto currentState = this->_sensor.getLastObservedState(); - auto currentStateAction = StateAction(currentState, action); - auto nextStateAndReward = this->getNextStateAndReward(currentStateAction); +spStateAndReward Environment::applyAction(const spAction &action) { + spState currentState = this->_sensor.getLastObservedState(); + StateAction currentStateAction (currentState, action); + spStateAndReward nextStateAndReward = this->getNextStateAndReward(currentStateAction); this->_sensor.setLastObservedStateAndReward(nextStateAndReward); return nextStateAndReward; diff --git a/include/agent/Sensor.h b/include/agent/Sensor.h index 5b99165..86d1bfb 100644 --- a/include/agent/Sensor.h +++ b/include/agent/Sensor.h @@ -32,12 +32,12 @@ namespace agent { template class Sensor { public: - Sensor(const S &initialState); + Sensor(const spState &initialState); /** * @return current state of agent in environment. */ - virtual const S &getLastObservedState() const; + virtual const spState &getLastObservedState() const; /** * Maps sensorState to its corresponding reward. @@ -50,13 +50,13 @@ class Sensor { * Set the last observed state. * @param s Last observed state. */ - virtual void setLastObservedState(const S &s); + virtual void setLastObservedState(const spState &s); /** * Changes initial state. * @param s New initial state. */ - virtual void setInitialState(const S &s); + virtual void setInitialState(const spState &s); /** * Set the last observed reward @@ -69,7 +69,7 @@ class Sensor { * @param stateAndReward pair of state and reward. 
*/ virtual void setLastObservedStateAndReward( - const StateAndReward &stateAndReward); + const spStateAndReward &stateAndReward); /** * Resets last observed state to initial state. @@ -80,22 +80,22 @@ * @param stateData to determine if it is a terminal state. * @return true if its a terminal state. */ - virtual bool isTerminalState(const S &stateData) const = 0; + virtual bool isTerminalState(const spState &stateData) const = 0; private: FLOAT _lastObservedReward = NAN; - S _initialState; - S _lastObservedState; + spState _initialState; + spState _lastObservedState; }; template -Sensor::Sensor(const S &initialState) : +Sensor::Sensor(const spState &initialState) : _initialState(initialState), _lastObservedState(_initialState) { } template -const S &Sensor::getLastObservedState() const { +const spState &Sensor::getLastObservedState() const { return this->_lastObservedState; } @@ -105,12 +105,12 @@ rl::FLOAT Sensor::getLastObservedReward() const { } template -void Sensor::setLastObservedState(const S &s) { +void Sensor::setLastObservedState(const spState &s) { this->_lastObservedState = s; } template -void Sensor::setInitialState(const S &s) { +void Sensor::setInitialState(const spState &s) { this->_initialState = s; } @@ -121,7 +121,7 @@ void Sensor::setLastObservedReward(FLOAT r) { template void Sensor::setLastObservedStateAndReward( - const StateAndReward &stateAndReward) { + const spStateAndReward &stateAndReward) { this->setLastObservedState(std::get<0>(stateAndReward)); this->setLastObservedReward(std::get<1>(stateAndReward)); } diff --git a/include/agent/SensorDiscrete.h b/include/agent/SensorDiscrete.h index 550bb90..eeb1841 100644 --- a/include/agent/SensorDiscrete.h +++ b/include/agent/SensorDiscrete.h @@ -27,26 +27,26 @@ class SensorDiscrete : public Sensor { public: using Sensor::Sensor; - virtual bool isTerminalState(const S &stateData) const override; + virtual bool isTerminalState(const spState &stateData) const override; /** * @param terminalData new terminal state to be added. */ - virtual void addTerminalState(const S &terminalData); + virtual void addTerminalState(const spState &terminalData); private: - set _terminalStates; // Must know when to stop. + spStateSet _terminalStates; // Must know when to stop. }; template bool SensorDiscrete::isTerminalState( - const S &stateData) const { + const spState &stateData) const { return _terminalStates.find(stateData) != _terminalStates.end(); } template void SensorDiscrete::addTerminalState( - const S &terminalData) { + const spState &terminalData) { _terminalStates.insert(terminalData); } diff --git a/include/agent/StateAction.h b/include/agent/StateAction.h index 7439370..177554f 100644 --- a/include/agent/StateAction.h +++ b/include/agent/StateAction.h @@ -5,8 +5,9 @@ * Author: jandres */ -#ifndef STATEACTION_H_ -#define STATEACTION_H_ +#pragma once + +#include "../declares.h" namespace rl { namespace agent { @@ -32,7 +33,7 @@ class StateAction { * @param state of state-action pair. * @param action of state-action pair. */ - StateAction(S state, A action); + explicit StateAction(const spState& state, const spAction& action); /** * Copy Constructor. @@ -50,52 +51,55 @@ /** * @return return state of state-action pair. */ - const S &getState() const; + const spState &getState() const; /** * @return return action of state-action pair. */ - const A &getAction() const; + const spAction &getAction() const; /** * @param state set the state of state-action pair.
*/ - void setState(const S &state); + void setState(const spState &state); /** * @param action set the action of state-action pair. */ - void setAction(const A &action); + void setAction(const spAction &action); protected: - S _state; //!< State of state-action pair. - A _action; //!< Action of state-action pair. + spState _state; //!< State of state-action pair. + spAction _action; //!< Action of state-action pair. }; template -StateAction::StateAction(S state, A action) : _state(state), _action(action) {} +StateAction::StateAction(const spState& state, const spAction& action) : + _state(state), + _action(action) { +} template StateAction::StateAction(const StateAction &sa) : _state(sa._state), _action(sa._action) {} template bool StateAction::operator<(const StateAction &stateAction) const { - if (_state < stateAction._state) + if (*_state < *(stateAction._state)) return true; - if (_state > stateAction._state) + if (*_state > *(stateAction._state)) return false; - if (_action < stateAction._action) + if (*_action < *(stateAction._action)) return true; return false; } template bool StateAction::operator>(const StateAction &stateAction) const { - if (_state > stateAction._state) + if (*_state > *(stateAction._state)) return true; - if (_state < stateAction._state) + if (*_state < *(stateAction._state)) return false; - if (_action > stateAction._action) + if (*_action > *(stateAction._action)) return true; return false; } @@ -112,7 +116,7 @@ bool StateAction::operator>=(const StateAction &stateAction) const { template bool StateAction::operator==(const StateAction &stateAction) const { - if (_state == stateAction._state && _action == stateAction._action) { + if (*_state == *(stateAction._state) && *_action == *(stateAction._action)) { return true; } return false; @@ -127,26 +131,24 @@ bool StateAction::operator!=(const StateAction &stateAction) const { } template -const S &StateAction::getState() const { +const spState &StateAction::getState() const { return this->_state; } template -const A &StateAction::getAction() const { +const spAction &StateAction::getAction() const { return this->_action; } template -void StateAction::setState(const S &state) { +void StateAction::setState(const spState &state) { _state = state; } template -void StateAction::setAction(const A &action) { +void StateAction::setAction(const spAction &action) { _action = action; } } // namespace agent } /* namespace rl */ - -#endif /* STATEACTION_H_ */ diff --git a/include/agent/StateActionPairContainer.h b/include/agent/StateActionPairContainer.h index 7e19806..7869e24 100644 --- a/include/agent/StateActionPairContainer.h +++ b/include/agent/StateActionPairContainer.h @@ -14,7 +14,7 @@ #include #include #include -#include +#include #include "StateAction.h" #include "StateActionNotExistException.h" @@ -44,8 +44,8 @@ class StateActionPairContainer { * @param value Value of the state to be added. * @param actionArray */ - virtual void addState(const S &state, rl::FLOAT value, - const set &actionSet); + virtual void addState(const spState &state, rl::FLOAT value, + const spActionSet &actionSet); /** * Adds a new state with the corresponding action. @@ -60,14 +60,14 @@ class StateActionPairContainer { * @param value * @param actionSet */ - virtual void addAction(const A &action, rl::FLOAT value, const set &stateSet); + virtual void addAction(const spAction &action, rl::FLOAT value, const spStateSet &stateSet); /** * @param state to be search in the _stateActionPairMap. * @return true if state is in _stateActionPairMap. 
*/ - virtual bool stateInStateActionPairMap(const S &state, - const set &actionSet) const; + virtual bool stateInStateActionPairMap(const spState &state, + const spActionSet &actionSet) const; /** * @param stateAction @@ -125,9 +125,9 @@ StateActionPairContainer::StateActionPairContainer() { } template -void StateActionPairContainer::addState(const S &state, rl::FLOAT value, - const set &actionSet) { - for (const A &action : actionSet) { +void StateActionPairContainer::addState(const spState &state, rl::FLOAT value, + const spActionSet &actionSet) { + for (auto action : actionSet) { _stateActionPairMap.insert( std::pair, rl::FLOAT>( StateAction(state, action), value)); @@ -142,9 +142,9 @@ void StateActionPairContainer::addStateAction(const StateAction &sta }; template -void StateActionPairContainer::addAction(const A &action, rl::FLOAT value, - const set &stateSet) { - for (const S &state : stateSet) { +void StateActionPairContainer::addAction(const spAction &action, rl::FLOAT value, + const spStateSet &stateSet) { + for (auto state : stateSet) { _stateActionPairMap.insert( std::pair, rl::FLOAT>( StateAction(state, action), value)); @@ -153,8 +153,8 @@ void StateActionPairContainer::addAction(const A &action, rl::FLOAT value, template bool StateActionPairContainer::stateInStateActionPairMap( - const S &state, const set &actionSet) const { - const A &sampleAction = *(actionSet.begin()); + const spState &state, const spActionSet &actionSet) const { + const spAction &sampleAction = *(actionSet.begin()); bool rv = _stateActionPairMap.find(StateAction(state, sampleAction)) != _stateActionPairMap.end(); return rv; @@ -167,6 +167,7 @@ throw(StateActionNotExistException) { try { _stateActionPairMap.at(stateAction); } catch (const std::out_of_range &oor) { + cerr << "State-Pair given is not yet added. " << __FILE__ ":" << __LINE__ << std::endl; StateActionNotExistException exception("State-Pair given is not yet added."); throw exception; } @@ -181,6 +182,7 @@ throw(StateActionNotExistException) { try { return _stateActionPairMap.at(stateAction); } catch (const std::out_of_range &oor) { + cerr << "State-Pair given is not yet added. " << __FILE__ ":" << __LINE__ << std::endl; StateActionNotExistException exception("State-Pair given is not yet added."); throw exception; } @@ -221,4 +223,3 @@ inline multimap> rl::agent::StateActionP } // namespace rl #endif /* STATEACTIONPAIRCONTAINER_H */ - diff --git a/include/agent/StateActionTransition.h b/include/agent/StateActionTransition.h index a78aa68..b53a333 100644 --- a/include/agent/StateActionTransition.h +++ b/include/agent/StateActionTransition.h @@ -8,15 +8,15 @@ #ifndef STATEACTIONTRANSITION_H #define STATEACTIONTRANSITION_H -#include "../declares.h" - #include #include #include #include #include #include +#include +#include "../declares.h" #include "StateActionTransitionException.h" #include "../algorithm/QLearning.h" @@ -62,7 +62,7 @@ class StateActionTransition { * @param nextState to be added or increased frequency value and update reward value. * @param reward to update the value of the nextState. */ - virtual void update(const S& nextState, const rl::FLOAT reward); + virtual void update(const spState& nextState, const rl::FLOAT reward); /** * Given a state returns its latest reward info. @@ -70,7 +70,7 @@ class StateActionTransition { * @return reward of the state. * @throw StateActionTransitionException when given state don't exist. 
*/ - virtual rl::FLOAT getReward(const S& state) const + virtual rl::FLOAT getReward(const spState& state) const throw (StateActionTransitionException); /** @@ -78,7 +78,7 @@ class StateActionTransition { * chance of occuring. * @throw StateActionTransitionException when given state don't exist. */ - virtual const S& getNextState() const throw (StateActionTransitionException); + virtual const spState& getNextState() const throw (StateActionTransitionException); /** * @return the number of transition states. @@ -106,13 +106,13 @@ class StateActionTransition { rl::FLOAT getGreedy() const; private: - bool _findState(const S& state) const; + bool _findState(const spState& state) const; private: // Keeps track of all the possible transistion states and their // corresponding frequency and reward. - map _stateActionTransitionFrequency; - map _stateActionTransitionReward; + spStateXMap _stateActionTransitionFrequency; + spStateXMap _stateActionTransitionReward; rl::FLOAT _greedy; rl::FLOAT _stepSize; @@ -137,18 +137,18 @@ StateActionTransition::StateActionTransition( } template -void StateActionTransition::update(const S& nextState, +void StateActionTransition::update(const spState& nextState, const rl::FLOAT reward) { _stateActionTransitionFrequency.insert( - std::pair(nextState, 0.0F)); + spStateAndReward(nextState, 0.0F)); _stateActionTransitionReward.insert( - std::pair(nextState, reward)); + spStateAndReward(nextState, reward)); // Update Frequency. // --Lower the value of all other frequencies. for (auto iter = _stateActionTransitionFrequency.begin(); - iter != _stateActionTransitionFrequency.end(); iter++) { - const S& state = iter->first; + iter != _stateActionTransitionFrequency.end(); iter++) { + auto state = iter->first; if (state != nextState) { _stateActionTransitionFrequency[state] = _stateActionTransitionFrequency[state] @@ -172,7 +172,7 @@ void StateActionTransition::update(const S& nextState, } template -rl::FLOAT StateActionTransition::getReward(const S& state) const +rl::FLOAT StateActionTransition::getReward(const spState& state) const throw (StateActionTransitionException) { StateActionTransitionException exception( "StateActionTransition::getReward(const S& state): state not yet added."); @@ -184,14 +184,14 @@ rl::FLOAT StateActionTransition::getReward(const S& state) const } template -bool StateActionTransition::_findState(const S& state) const { +bool StateActionTransition::_findState(const spState& state) const { bool found = _stateActionTransitionFrequency.find(state) != _stateActionTransitionFrequency.end(); return found; } template -const S& StateActionTransition::getNextState() const +const spState& StateActionTransition::getNextState() const throw (StateActionTransitionException) { StateActionTransitionException exception( "StateActionTransition::getNextState(): nextStates are empty."); @@ -209,8 +209,7 @@ const S& StateActionTransition::getNextState() const if (randomNumberRandomSelection > _greedy) { auto it = _stateActionTransitionFrequency.begin(); std::advance(it, _randomDevice() % _stateActionTransitionFrequency.size()); - const S& nextState = it->first; - return nextState; + return it->first; } // http://stackoverflow.com/questions/1761626/weighted-random-numbers diff --git a/include/agent/StateInterface.h b/include/agent/StateInterface.h index a75cc3b..d50aff8 100644 --- a/include/agent/StateInterface.h +++ b/include/agent/StateInterface.h @@ -17,7 +17,7 @@ namespace agent { template class StateInterface { public: - StateInterface(const S &value) : 
_value(value) {} + StateInterface(const spState &value) : _value(value) {} virtual bool operator<(const StateInterface &rhs) const { return this->_value < rhs._value; @@ -52,16 +52,16 @@ class StateInterface { return *this; } - virtual S &getValue() { + virtual spState &getValue() { return this->_value; } - virtual const S &getValue() const { + virtual const spState &getValue() const { return this->_value; } protected: - S _value; // Value of this thing. + spState _value; // Value of this thing. }; } // namespace agent diff --git a/include/algorithm/DynaQ.h b/include/algorithm/DynaQ.h index 52e9934..031e478 100644 --- a/include/algorithm/DynaQ.h +++ b/include/algorithm/DynaQ.h @@ -50,9 +50,9 @@ class DynaQ : public DynaQRLMP { rl::FLOAT stateTransitionGreediness, rl::FLOAT stateTransitionStepSize); virtual void update(const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const FLOAT reward, - const set& actionSet) override; + const spActionSet& actionSet) override; }; template @@ -68,10 +68,10 @@ DynaQ::DynaQ(rl::FLOAT stepSize, rl::FLOAT discountRate, template void rl::algorithm::DynaQ::update( const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const FLOAT reward, - const set& actionSet) { - A nextAction = this->getLearningAction(nextState, actionSet); + const spActionSet& actionSet) { + spAction nextAction = this->getLearningAction(nextState, actionSet); DynaQRLMP::updateStateAction( currentStateAction, StateAction(nextState, nextAction), diff --git a/include/algorithm/DynaQBase.h b/include/algorithm/DynaQBase.h index b8545b0..d0a5e16 100644 --- a/include/algorithm/DynaQBase.h +++ b/include/algorithm/DynaQBase.h @@ -64,7 +64,7 @@ class DynaQBase { * @param actionSet a set of possible actions. * @return the action that will "likely" gives the highest reward. */ - virtual A argMax(const S& state, const set& actionSet) const = 0; + virtual spAction argMax(const spState& state, const spActionSet& actionSet) const = 0; /** * Update the stateAction map. @@ -75,7 +75,7 @@ class DynaQBase { * @param actionSet A set of all actions. */ virtual void _updateModel(const StateAction& currentStateAction, - const S& nextState, const FLOAT reward); + const spState& nextState, const FLOAT reward); /** * Adds new stateAction pair to the model with mutex lock. @@ -93,7 +93,7 @@ class DynaQBase { * Performs simulation _simulationIterationCount times. * @param actionSet set of actions of agent.
*/ - virtual void _simulate(const set& actionSet); + virtual void _simulate(const spActionSet& actionSet); protected: rl::UINT _simulationIterationCount; //!< Number of simulation, the higher the value @@ -125,7 +125,7 @@ inline DynaQBase::DynaQBase( template void DynaQBase::_updateModel( - const StateAction& currentStateAction, const S& nextState, + const StateAction& currentStateAction, const spState& nextState, const FLOAT reward) { _addModel(currentStateAction); @@ -151,7 +151,7 @@ inline void DynaQBase::_addModel( } template -void DynaQBase::_simulate(const set& actionSet) { +void DynaQBase::_simulate(const spActionSet& actionSet) { if (_model.empty()) return; @@ -163,9 +163,9 @@ void DynaQBase::_simulate(const set& actionSet) { const StateActionTransition& st = _model.at(item->first); - const S& transState = st.getNextState(); + const spState& transState = st.getNextState(); - A nextAction = argMax(transState, actionSet); + spAction nextAction = argMax(transState, actionSet); backUpStateActionPair(item->first, st.getReward(transState), StateAction(transState, nextAction)); diff --git a/include/algorithm/DynaQET.h b/include/algorithm/DynaQET.h index cb62f38..67a8db7 100644 --- a/include/algorithm/DynaQET.h +++ b/include/algorithm/DynaQET.h @@ -48,17 +48,19 @@ class DynaQET final: public DynaQ, public EligibilityTraces { * Small \f$\lambda\f$ converges to TD(0). */ DynaQET(rl::FLOAT stepSize, rl::FLOAT discountRate, - policy::Policy& policy, rl::UINT simulationIterationCount, + policy::Policy& policy, + rl::UINT simulationIterationCount, rl::FLOAT stateTransitionGreediness, - rl::FLOAT stateTransitionStepSize, rl::FLOAT lambda); + rl::FLOAT stateTransitionStepSize, + rl::FLOAT lambda); public: // Inherited. virtual void update(const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const rl::FLOAT currentStateActionValue, - const set& actionSet) override; + const spActionSet& actionSet) override; }; template @@ -74,9 +76,9 @@ DynaQET::DynaQET(rl::FLOAT stepSize, rl::FLOAT discountRate, template void DynaQET::update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) { - A nextAction = this->getLearningAction(nextState, actionSet); + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) { + spAction nextAction = this->getLearningAction(nextState, actionSet); // For some reason I need to call the grand parent class ReinforcementLearning. 
DynaQ::updateStateAction( diff --git a/include/algorithm/DynaQPrioritizedSweeping.h b/include/algorithm/DynaQPrioritizedSweeping.h index f389ed3..a621554 100644 --- a/include/algorithm/DynaQPrioritizedSweeping.h +++ b/include/algorithm/DynaQPrioritizedSweeping.h @@ -63,11 +63,11 @@ class DynaQPrioritizeSweeping final: public DynaQ { rl::FLOAT priorityThreshold); virtual void update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) override; + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) override; protected: - void _prioritySweep(const set& actionSet); + void _prioritySweep(const spActionSet& actionSet); rl::FLOAT _getTDError(const StateAction& currentStateAction, const rl::FLOAT reward, const StateAction& nextStateAction); @@ -88,7 +88,7 @@ DynaQPrioritizeSweeping::DynaQPrioritizeSweeping( } template -void DynaQPrioritizeSweeping::_prioritySweep(const set& actionSet) { +void DynaQPrioritizeSweeping::_prioritySweep(const spActionSet& actionSet) { // Repeat n times while priority queue is not empty. for (rl::UINT i = 0; i < this->_simulationIterationCount; i++) { if (_priorityQueue.empty()) @@ -100,11 +100,11 @@ void DynaQPrioritizeSweeping::_prioritySweep(const set& actionSet) { // Acquire next reward from model. const StateActionTransition& currentStateActionTransition = this->_model .at(currentStateActionPair); - const S& nextState = currentStateActionTransition.getNextState(); + auto nextState = currentStateActionTransition.getNextState(); rl::FLOAT nextReward = currentStateActionTransition.getReward(nextState); // Acquire max_action(Q(S, a)). - A nextAction = this->getLearningAction(nextState, actionSet); + auto nextAction = this->getLearningAction(nextState, actionSet); StateAction nextStateAction(nextState, nextAction); @@ -116,13 +116,13 @@ void DynaQPrioritizeSweeping::_prioritySweep(const set& actionSet) { this->_model.begin(); iter != this->_model.end(); iter++) { // Acquire a model(S) and look for a model(S) that transition to current(S, A). StateActionTransition& modelStateTransition = iter->second; - const S& modelNextState = modelStateTransition.getNextState(); + auto modelNextState = modelStateTransition.getNextState(); // If model(S) -> current(S, A), that means model(S) factors in reaching terminal state. // And because of that, back up model(S). 
if (modelNextState == currentStateActionPair.getState()) { rl::FLOAT priorReward = modelStateTransition.getReward(modelNextState); - A nextModelAction = this->argMax(modelNextState, actionSet); + auto nextModelAction = this->argMax(modelNextState, actionSet); rl::FLOAT temptdError = _getTDError( iter->first, priorReward, @@ -150,11 +150,11 @@ rl::FLOAT DynaQPrioritizeSweeping::_getTDError( template void DynaQPrioritizeSweeping::update( const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const rl::FLOAT reward, - const set& actionSet) { - A nextAction = this->argMax(nextState, actionSet); - auto nextStateAction = StateAction(nextState, nextAction); + const spActionSet& actionSet) { + spAction nextAction = this->argMax(nextState, actionSet); + StateAction nextStateAction(nextState, nextAction); DynaQ::update(currentStateAction, nextState, reward, actionSet); diff --git a/include/algorithm/DynaQRLMP.h b/include/algorithm/DynaQRLMP.h index 878cad4..4730aa6 100644 --- a/include/algorithm/DynaQRLMP.h +++ b/include/algorithm/DynaQRLMP.h @@ -58,8 +58,8 @@ class DynaQRLMP : public ReinforcementLearning, public DynaQBase { * @param nextStateActionPair \f$(S', A')\f$, next state-action pair. */ virtual void backUpStateActionPair( - const StateAction& currentStateAction, const rl::FLOAT reward, - const StateAction& nextStateActionPair); + const StateAction& currentStateAction, const rl::FLOAT reward, + const StateAction& nextStateActionPair); /** * Returns the action that will most "likely" gives the highest reward from the @@ -69,30 +69,30 @@ class DynaQRLMP : public ReinforcementLearning, public DynaQBase { * @param actionSet a set of possible actions. * @return the action that will "likely" gives the highest reward. */ - virtual A argMax(const S& state, const set& actionSet) const; + virtual spAction argMax(const spState& state, const spActionSet& actionSet) const; }; template inline DynaQRLMP::DynaQRLMP( - rl::FLOAT stepSize, rl::FLOAT discountRate, policy::Policy& policy, - rl::UINT simulationIterationCount, rl::FLOAT stateTransitionGreediness, - rl::FLOAT stateTransitionStepSize) - : ReinforcementLearning(stepSize, discountRate, policy), - DynaQBase(simulationIterationCount, stateTransitionGreediness, - stateTransitionStepSize) { + rl::FLOAT stepSize, rl::FLOAT discountRate, policy::Policy& policy, + rl::UINT simulationIterationCount, rl::FLOAT stateTransitionGreediness, + rl::FLOAT stateTransitionStepSize) + : ReinforcementLearning(stepSize, discountRate, policy), + DynaQBase(simulationIterationCount, stateTransitionGreediness, + stateTransitionStepSize) { } template inline void DynaQRLMP::backUpStateActionPair( - const StateAction& currentStateAction, const rl::FLOAT reward, - const StateAction& nextStateActionPair) { + const StateAction& currentStateAction, const rl::FLOAT reward, + const StateAction& nextStateActionPair) { ReinforcementLearning::backUpStateActionPair(currentStateAction, reward, nextStateActionPair); } template -inline A DynaQRLMP::argMax(const S& state, - const set& actionSet) const { +inline spAction DynaQRLMP::argMax(const spState& state, + const spActionSet& actionSet) const { return ReinforcementLearning::argMax(state, actionSet); } diff --git a/include/algorithm/EligibilityTraces.h b/include/algorithm/EligibilityTraces.h index ae86484..e03e77a 100644 --- a/include/algorithm/EligibilityTraces.h +++ b/include/algorithm/EligibilityTraces.h @@ -82,7 +82,7 @@ class EligibilityTraces { protected: rl::FLOAT _lambda; - map, rl::FLOAT> 
_eligibilityTraces; + map, rl::FLOAT> _eligibilityTraces; }; template @@ -116,8 +116,8 @@ void EligibilityTraces::_updateEligibilityTraces( for (auto iter = stateActionPairValue.begin(); iter != stateActionPairValue.end(); iter++) { - const S& state = iter->first.getState(); - const A& action = iter->first.getAction(); + auto state = iter->first.getState(); + auto action = iter->first.getAction(); rl::FLOAT value = iter->second; stateActionPairValue.setStateActionValue( diff --git a/include/algorithm/LearningAlgorithm.h b/include/algorithm/LearningAlgorithm.h index 2f2dcd0..e821cee 100644 --- a/include/algorithm/LearningAlgorithm.h +++ b/include/algorithm/LearningAlgorithm.h @@ -46,16 +46,16 @@ class LearningAlgorithm { * @param reward Reward for the transition from current-state action to next state-action. */ virtual void update(const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const rl::FLOAT reward, - const set& actionSet) = 0; + const spActionSet& actionSet) = 0; /** * @param state the state to take the action to. * @param actionSet set of possible actions. * @return action based on control policy and current state. */ - virtual A getAction(const S& state, const set& actionSet) = 0; + virtual spAction getAction(const spState& state, const spActionSet& actionSet) = 0; /** * @param stateAction @@ -116,10 +116,11 @@ class LearningAlgorithm { * @param actionSet current action set. * @return action selected by learning policy. */ - A _getLearningPolicyAction(const map& actionValueMap, - const set& actionSet); - A _getLearningPolicyAction(const map& actionValueMap, - const set& actionSet, ACTION_CONT& action); + spAction _getLearningPolicyAction(const spActionValueMap& actionValueMap, + const spActionSet& actionSet); + spAction _getLearningPolicyAction(const spActionValueMap& actionValueMap, + const spActionSet& actionSet, + const spAction& action); protected: rl::FLOAT _defaultStateActionValue; //!< Place holder for default state action value. 
@@ -179,15 +180,17 @@ inline const policy::Policy& rl::algorithm::LearningAlgorithm< } template -inline A LearningAlgorithm::_getLearningPolicyAction( - const map& actionValueMap, const set& actionSet) { +inline spAction LearningAlgorithm::_getLearningPolicyAction( + const spActionValueMap& actionValueMap, + const spActionSet& actionSet) { return _learningPolicy->getAction(actionValueMap, actionSet); } template -inline A LearningAlgorithm::_getLearningPolicyAction( - const map& actionValueMap, const set& actionSet, - ACTION_CONT& action) { +inline spAction LearningAlgorithm::_getLearningPolicyAction( + const spActionValueMap& actionValueMap, + const spActionSet& actionSet, + const spAction& action) { return _learningPolicy->getAction(actionValueMap, actionSet, action); } diff --git a/include/algorithm/QLearning.h b/include/algorithm/QLearning.h index d0bcbf8..12431ca 100644 --- a/include/algorithm/QLearning.h +++ b/include/algorithm/QLearning.h @@ -43,8 +43,8 @@ class QLearning : public ReinforcementLearning { policy::Policy& policy); virtual void update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) override; + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) override; }; template @@ -55,11 +55,12 @@ QLearning::QLearning(rl::FLOAT stepSize, rl::FLOAT discountRate, template void QLearning::update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) { + const spState& nextState, + const rl::FLOAT reward, + const spActionSet& actionSet) { // Note: this algorithm is in pg. 145 of Sutton Barto 2nd edition. // Q(S, A) <- Q(S, A) + α[ R + γ max a Q(S' , a) − Q(S, A)] - A nextAction = this->getLearningAction(nextState, actionSet); + spAction nextAction = this->getLearningAction(nextState, actionSet); ReinforcementLearning::updateStateAction(currentStateAction, StateAction(nextState, nextAction), diff --git a/include/algorithm/QLearningET.h b/include/algorithm/QLearningET.h index b109e93..53f8e3b 100644 --- a/include/algorithm/QLearningET.h +++ b/include/algorithm/QLearningET.h @@ -39,18 +39,18 @@ class QLearningET final: public EligibilityTraces, public QLearning // Inherited. virtual void update(const StateAction& currentStateAction, - const S& nextState, + const spState& nextState, const rl::FLOAT currentStateActionValue, - const set& actionSet); + const spActionSet& actionSet); private: }; template void QLearningET::update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) { - A nextAction = this->getLearningAction(nextState, actionSet); + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) { + spAction nextAction = this->getLearningAction(nextState, actionSet); ReinforcementLearning::updateStateAction( currentStateAction, StateAction(nextState, nextAction), diff --git a/include/algorithm/ReinforcementLearning.h b/include/algorithm/ReinforcementLearning.h index 51a58c6..21ea157 100644 --- a/include/algorithm/ReinforcementLearning.h +++ b/include/algorithm/ReinforcementLearning.h @@ -51,7 +51,7 @@ class ReinforcementLearning : public LearningAlgorithm { * @param actionSet a set of possible actions. * @return the action that will "likely" gives the highest reward. */ - A argMax(const S& state, const set& actionSet) const; + spAction argMax(const spState& state, const spActionSet& actionSet) const; /** * @return current discount rate.
@@ -90,7 +90,7 @@ class ReinforcementLearning : public LearningAlgorithm { * @param actionSet Set of actions. * @return Action with respect to learning/offline policy. */ - A getLearningAction(const S& currentState, const set& actionSet); + spAction getLearningAction(const spState& currentState, const spActionSet& actionSet); /** * @param stateAction to acquire a value of. @@ -110,6 +110,7 @@ class ReinforcementLearning : public LearningAlgorithm { * @return state-action pair container. */ const StateActionPairContainer& getStateActionPairContainer() const; + StateActionPairContainer& getStateActionPairContainer(); /** * @param stateActionPairContainer set state-action pair container. @@ -122,11 +123,11 @@ class ReinforcementLearning : public LearningAlgorithm { virtual void updateStateAction(const StateAction ¤tStateAction, const StateAction &nextStateAction, const rl::FLOAT reward); - virtual A getAction(const S& currentState, const set& actionSet); + virtual spAction getAction(const spState& currentState, const spActionSet& actionSet); protected: - void _buildActionValueMap(const set& actionSet, const S& currentState, - map& actionValueMap); + void _buildActionValueMap(const spActionSet& actionSet, const spState& currentState, + spActionValueMap& actionValueMap); protected: rl::FLOAT _stepSize; rl::FLOAT _discountRate; @@ -154,18 +155,18 @@ ReinforcementLearning::ReinforcementLearning( } template -A ReinforcementLearning::argMax( - const S& state, const set& actionSet) const { - A greedAct = *(actionSet.begin()); +spAction ReinforcementLearning::argMax( + const spState& state, const spActionSet& actionSet) const { + spAction greedAct = *(actionSet.begin()); rl::FLOAT currentValue = this->_defaultStateActionValue; try { - currentValue = this->_stateActionPairContainer[{state, greedAct}]; + currentValue = this->_stateActionPairContainer[StateAction(state, greedAct)]; } catch (StateActionNotExistException& e) { // Do nothing. We already assign it the default state-action value. 
} - for (const A& action : actionSet) { + for (const spState& action : actionSet) { rl::FLOAT value = this->_defaultStateActionValue; try { value = this->_stateActionPairContainer[StateAction(state, action)]; @@ -209,6 +210,12 @@ inline const StateActionPairContainer& ReinforcementLearning< return _stateActionPairContainer; } +template +inline StateActionPairContainer& ReinforcementLearning< + S, A>::getStateActionPairContainer() { + return _stateActionPairContainer; +} + template inline void ReinforcementLearning::setStateActionPairContainer( const StateActionPairContainer& stateActionPairContainer) { @@ -222,31 +229,31 @@ inline void ReinforcementLearning::setStateActionValue( } template -inline A ReinforcementLearning::getLearningAction( - const S& currentState, const set& actionSet) { +inline spAction ReinforcementLearning::getLearningAction( + const spState& currentState, const spActionSet& actionSet) { _stateActionPairContainer.addState(currentState, this->_defaultStateActionValue, actionSet); - map actionValueMap; + spActionValueMap actionValueMap; _buildActionValueMap(actionSet, currentState, actionValueMap); return this->_getLearningPolicyAction(actionValueMap, actionSet); } template void ReinforcementLearning::_buildActionValueMap( - const set& actionSet, const S& currentState, - map& actionValueMap) { - for (const A& action : actionSet) { - actionValueMap[action] = _stateActionPairContainer[StateAction( - currentState, action)]; + const spActionSet& actionSet, const spState& currentState, + spActionValueMap& actionValueMap) { + for (const spAction& action : actionSet) { + actionValueMap[action] = _stateActionPairContainer[ + StateAction(currentState, action)]; } } template -A ReinforcementLearning::getAction( - const S& currentState, const set& actionSet) { +spAction ReinforcementLearning::getAction( + const spState& currentState, const spActionSet& actionSet) { _stateActionPairContainer.addState(currentState, this->_defaultStateActionValue, actionSet); - map actionValueMap; + spActionValueMap actionValueMap; _buildActionValueMap(actionSet, currentState, actionValueMap); return this->_controlPolicy->getAction(actionValueMap, actionSet); } diff --git a/include/algorithm/Sarsa.h b/include/algorithm/Sarsa.h index c9c1383..a4a6e77 100644 --- a/include/algorithm/Sarsa.h +++ b/include/algorithm/Sarsa.h @@ -47,12 +47,14 @@ class Sarsa : public ReinforcementLearning { // Inherited. virtual void update(const SA& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet); + const spState& nextState, + const rl::FLOAT reward, + const spActionSet& actionSet); }; template -Sarsa::Sarsa(rl::FLOAT stepSize, rl::FLOAT discountRate, +Sarsa::Sarsa(rl::FLOAT stepSize, + rl::FLOAT discountRate, policy::Policy& policy) : ReinforcementLearning(stepSize, discountRate, policy) { this->setLearningPolicy(policy); @@ -61,9 +63,9 @@ Sarsa::Sarsa(rl::FLOAT stepSize, rl::FLOAT discountRate, // TODO: Make setLearningPolicy and setPolicy the same. 
template void Sarsa::update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) { - A nextAction = this->getAction(nextState, actionSet); + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) { + spAction nextAction = this->getAction(nextState, actionSet); ReinforcementLearning::updateStateAction(currentStateAction, SA(nextState, nextAction), reward); this->backUpStateActionPair(currentStateAction, reward, SA(nextState, nextAction)); } diff --git a/include/algorithm/SarsaET.h b/include/algorithm/SarsaET.h index c9fca9c..b7322d7 100644 --- a/include/algorithm/SarsaET.h +++ b/include/algorithm/SarsaET.h @@ -46,15 +46,15 @@ class SarsaET final: public EligibilityTraces, public Sarsa { policy::Policy& policy, rl::FLOAT lambda); virtual void update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) override; + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) override; }; template void SarsaET::update(const StateAction& currentStateAction, - const S& nextState, const rl::FLOAT reward, - const set& actionSet) { - A nextAction = this->getAction(nextState, actionSet); + const spState& nextState, const rl::FLOAT reward, + const spActionSet& actionSet) { + spAction nextAction = this->getAction(nextState, actionSet); ReinforcementLearning::updateStateAction( currentStateAction, StateAction(nextState, nextAction), diff --git a/include/algorithm/gradient-descent/GradientDescent.h b/include/algorithm/gradient-descent/GradientDescent.h index 0a0991a..442792d 100644 --- a/include/algorithm/gradient-descent/GradientDescent.h +++ b/include/algorithm/gradient-descent/GradientDescent.h @@ -7,14 +7,13 @@ #ifndef _GRADIENT_DESCENT_H_ #define _GRADIENT_DESCENT_H_ -#include "../../declares.h" - #include #include #include #include #include +#include "../../declares.h" #include "../../coding/TileCode.h" using namespace std; @@ -51,7 +50,7 @@ class GradientDescent { * @param parameters * @return corresponding value. */ - FLOAT getValueFromParameters(const vector& parameters) const; + FLOAT getValueFromParameters(const floatVector& parameters) const; /** * Get the value of the parameters in the real space. @@ -65,7 +64,7 @@ class GradientDescent { * @param fv feature vector output. Feature vector are samples taken around * the parameters in the n-dimension tilecde. */ - FEATURE_VECTOR getFeatureVector(const vector& parameters) const; + FEATURE_VECTOR getFeatureVector(const floatVector& parameters) const; /** * Increase the eligibility traces of a given feature vector. @@ -96,9 +95,9 @@ class GradientDescent { * @param nextStateVector array of next states. * @param reward reward for taking nextAction. */ - void updateWeights(const STATE_CONT& currentStateVector, - const ACTION_CONT& currentActionVector, - const STATE_CONT& nextStateVector, + void updateWeights(const spStateCont& currentStateVector, + const spActionCont& currentActionVector, + const spStateCont& nextStateVector, const FLOAT nextActionValue, const FLOAT reward); /** @@ -108,16 +107,16 @@ class GradientDescent { * @param maxAction max action calculated while building action value map. 
*/ void buildActionValues( - const set& actionSet, const vector& param, - map& actionVectorValueMap, - ACTION_CONT& maxAction) const; + const spActionSet& actionSet, const spStateCont& param, + spActionValueMap& actionVectorValueMap, + spActionCont& maxAction) const; /** * @param actionValueMap state-action to value mapping. * @return value. */ FLOAT getMaxValue( - const map& actionValueMap) const; + const spActionValueMap& actionValueMap) const; /** * Update weights with tderror. diff --git a/include/algorithm/gradient-descent/ReinforcementLearningGD.h b/include/algorithm/gradient-descent/ReinforcementLearningGD.h index 7ed97e4..c589130 100644 --- a/include/algorithm/gradient-descent/ReinforcementLearningGD.h +++ b/include/algorithm/gradient-descent/ReinforcementLearningGD.h @@ -24,7 +24,7 @@ using namespace coding; /*! \class ReinforcementLearningGD * \brief Gradient descent implementation of Reinforcement Learning. */ -class ReinforcementLearningGD : public LearningAlgorithm { +class ReinforcementLearningGD : public LearningAlgorithm { public: /** * @param tileCode tileCode implementation to be aggregated. @@ -37,7 +37,7 @@ class ReinforcementLearningGD : public LearningAlgorithm& policy); + rl::FLOAT lambda, policy::Policy& policy); /** * @param currentStateAction current state-action vector to apply. @@ -46,24 +46,24 @@ class ReinforcementLearningGD : public LearningAlgorithm& currentStateAction, - const STATE_CONT& nextStateVector, const FLOAT reward, - const set& actionSet); + const StateAction& currentStateAction, + const spStateCont& nextStateVector, const FLOAT reward, + const spActionSet& actionSet); /** * @param state State to take action to. * @param actionSet Set of possible actions. * @return Action determined by Control Policy. */ - virtual ACTION_CONT getAction(const STATE_CONT& state, - const set& actionSet); + virtual spActionCont getAction(const spStateCont& state, + const spActionSet& actionSet); /** * @param stateAction State-action pair to determine value of. * @return value of state-actio pair. */ virtual FLOAT getStateActionValue( - const StateAction& stateAction); + const StateAction& stateAction); /** * Reset routine for an algorithm for every episode. @@ -77,7 +77,7 @@ class ReinforcementLearningGD : public LearningAlgorithm& stateAction); + const StateAction& stateAction); /** * @param actionSet set of possible actions. @@ -85,10 +85,10 @@ class ReinforcementLearningGD : public LearningAlgorithm& actionSet, - const vector& nextState, - map& actionValueMap, - ACTION_CONT& maxAction); + void _buildActionValues(const spActionSet& actionSet, + const spStateCont& nextState, + spActionValueMap& actionValueMap, + spActionCont& maxAction); protected: GradientDescent _gradientDescent; diff --git a/include/coding/TileCode.h b/include/coding/TileCode.h index efbb5ba..5729e33 100644 --- a/include/coding/TileCode.h +++ b/include/coding/TileCode.h @@ -44,7 +44,7 @@ class TileCode { * @param parameters * @param Vector of "discretize" index. */ - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters) = 0; + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters) = 0; /** * @return size of the weight vector. 
diff --git a/include/coding/TileCodeCorrect.h b/include/coding/TileCodeCorrect.h index 336f1ec..cbb5477 100644 --- a/include/coding/TileCodeCorrect.h +++ b/include/coding/TileCodeCorrect.h @@ -39,7 +39,7 @@ class TileCodeCorrect : public TileCode { TileCodeCorrect(vector >& dimensionalInfos, size_t numTilings); - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters) override; + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters) override; }; diff --git a/include/coding/TileCodeMt1993764.h b/include/coding/TileCodeMt1993764.h index 563c5be..eec5e31 100644 --- a/include/coding/TileCodeMt1993764.h +++ b/include/coding/TileCodeMt1993764.h @@ -40,7 +40,7 @@ class TileCodeMt1993764 : public TileCode { * @param parameters * @return Vector of discretize index. */ - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters) override; + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters) override; protected: std::mt19937_64 _prng; diff --git a/include/coding/TileCodeMurMur.h b/include/coding/TileCodeMurMur.h index 7258ba4..2e91a99 100644 --- a/include/coding/TileCodeMurMur.h +++ b/include/coding/TileCodeMurMur.h @@ -20,7 +20,7 @@ class TileCodeMurMur : public TileCode { TileCodeMurMur(vector >& dimensionalInfos, size_t numTilings); TileCodeMurMur(vector >& dimensionalInfos, size_t numTilings, size_t sizeHint); - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters) override; + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters) override; }; } // namespace Coding diff --git a/include/coding/TileCodeSuperFastHash.h b/include/coding/TileCodeSuperFastHash.h index a253f20..9d930a3 100644 --- a/include/coding/TileCodeSuperFastHash.h +++ b/include/coding/TileCodeSuperFastHash.h @@ -39,7 +39,7 @@ class TileCodeSuperFastHash : public TileCode { * @param parameters * @return Vector of discretize index. */ - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters); + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters); }; } // namespace Coding diff --git a/include/coding/TileCodeUNH.h b/include/coding/TileCodeUNH.h index dd3fb38..987f0d0 100644 --- a/include/coding/TileCodeUNH.h +++ b/include/coding/TileCodeUNH.h @@ -30,7 +30,7 @@ class TileCodeUNH : public TileCode { * @param parameters * @return Vector of discretize index. */ - virtual FEATURE_VECTOR getFeatureVector(const STATE_CONT& parameters); + virtual FEATURE_VECTOR getFeatureVector(const floatVector& parameters); size_t mod(size_t n, size_t k) { return (n >= 0) ? n % k : k - 1 - ((-n - 1) % k); diff --git a/include/declares.h b/include/declares.h index 7992e2c..65d3cad 100644 --- a/include/declares.h +++ b/include/declares.h @@ -9,10 +9,13 @@ #include #include +#include +#include +#include +#include +#include #define DEBUG -//#define TEST_PRINT -//#define TEST_TILE_CODING #ifndef DEBUG #define NDEBUG // Gets rid of assertions. @@ -30,19 +33,145 @@ const FLOAT DEFAULT_GREEDINESS_LEARNING_POLICY = 1.0F; // Hashing constants. const UINT MURMUR_HASH_SEED = 0x666; +/** + * TODO(jandres): Add this to some factory generated module that sets settings. + */ const UINT MAX_EPISODES = 100000; -typedef std::vector STATE_CONT; //!< State container. -typedef std::vector ACTION_CONT; //!< Action container. - /*! \typedef FEATURE_VECTOR * Feature vector is a data structure for Tile Coding. It is the indices * that contains the data points to be sampled. + * + * TODO(jandres): See #RL-14 */ typedef std::vector FEATURE_VECTOR; +/*! 
\typedef floatVector + * \brief A vector of float. + * + * TODO(jandres): See #RL-14 + */ +using floatVector = std::vector; + +/*! \typedef spFloatVector + * \brief Wraps floatVector in shared_ptr. + * + * TODO(jandres): See #RL-14 + */ +using spFloatVector = std::shared_ptr>; + +/*! \typedef StateAndReward + * \tparam S data type of the State. + * + * Represents the pair of state and the corresponding reward. + */ template using StateAndReward = std::pair; -} +/*! \typedef spState + * \tparam S data type of the State. + * + * Wraps states in shared_ptr. This makes development easier since + * as we pass the states around, we are guaranteed they still exist. + */ +template +using spState = std::shared_ptr; + +/*! \typedef spStateComp + * \tparam S data type of the State. + * + * Comparison for spState should be by the value they dereference, + * not their pointer address value. + */ +template +class spStateComp { + public: + bool operator()(const spState& x, const spState& y) const { + return *x < *y; + } +}; + +/*! \typedef spAction + * \tparam A data type of the action. + * + * Wraps actions in shared_ptr. This makes development easier since + * as we pass the actions around, we are guaranteed they still exist. + */ +template +using spAction = std::shared_ptr; + +/*! \typedef spActionComp + * \tparam A data type of the action. + * + * Comparison for spAction should be by the value they dereference, + * not their pointer address value. + */ +template +class spActionComp { + public: + bool operator()(const spAction& x, const spAction& y) const { + return *x < *y; + } +}; + +/*! \typedef stateCont + * \brief State for Gradient descent learning. + * + * TODO(jandres): See #RL-14 + */ +typedef std::vector stateCont; //!< State container. + +/*! \typedef actionCont + * \brief Action for Gradient descent learning. + * + * TODO(jandres): See #RL-14 + */ +typedef std::vector actionCont; //!< Action container. + +/*! \typedef spStateCont + * \brief Wraps stateCont in shared_ptr. + */ +using spStateCont = spState; + +/*! \typedef spActionCont + * \brief Wraps actionCont in shared_ptr. + */ +using spActionCont = spState; + +/*! \typedef spStateAndReward + * \brief A pair of spState and its corresponding reward. + */ +template +using spStateAndReward = std::pair, FLOAT>; + +/*! \typedef spStateXMap + * \tparam S data type of state. + * \tparam X data type of arbitrary data spState is mapping to. + * \brief A mapping of spState to some arbitrary data. + */ +template +using spStateXMap = std::map, X, spStateComp>; + +/*! \typedef spStateSet + * \tparam S data type of the state. + * \brief A set of spState with an appropriate comparison object. + */ +template +using spStateSet = std::set, spStateComp>; + +/*! \typedef spActionSet + * \tparam A data type of the action. + * \brief A set of spAction with an appropriate comparison object. + */ +template +using spActionSet = std::set, spActionComp>; + +/*! \typedef spActionValueMap + * \tparam A data type of the action. + * \brief A map of spAction and its corresponding value.
+ */ +template +using spActionValueMap = std::map, FLOAT, spActionComp>; + +} // namespace rl diff --git a/include/policy/EpsilonGreedy.h b/include/policy/EpsilonGreedy.h index 9158f53..cd9ce38 100644 --- a/include/policy/EpsilonGreedy.h +++ b/include/policy/EpsilonGreedy.h @@ -5,16 +5,15 @@ * Created on June 6, 2014, 6:11 PM */ -#ifndef EPSILONGREEDY_H -#define EPSILONGREEDY_H - -#include "../declares.h" +#pragma once #include #include +#include -#include "Policy.h" +#include "../declares.h" #include "../agent/StateAction.h" +#include "Policy.h" using namespace std; @@ -49,12 +48,18 @@ class EpsilonGreedy : public Policy { * @param actionSet a set of possible actions. * @return the action that will "likely" gives the highest reward. */ - A argMax(const map& actionValues, const set& actionSet) const; - - virtual A getAction(const map& actionValues, const set& actionSet); - - virtual A getAction(const map& actionValues, const set& actionSet, - const A& maxAction); + spAction argMax( + const spActionValueMap& actionValues, + const spActionSet& actionSet) const; + + spAction getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet) override; + + spAction getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet, + const spAction& maxAction) override; /** * @param greediness @@ -71,26 +76,29 @@ class EpsilonGreedy : public Policy { rl::FLOAT _greediness; //!< Probability of selecting a greedy action. }; -typedef EpsilonGreedy, vector > EpsilonGreedySL; +typedef EpsilonGreedy, vector> EpsilonGreedySL; template -rl::policy::EpsilonGreedy::EpsilonGreedy(rl::FLOAT greediness) +EpsilonGreedy::EpsilonGreedy(rl::FLOAT greediness) : _greediness(greediness), _distribution(0.0F, 1.0F) { } template -rl::policy::EpsilonGreedy::~EpsilonGreedy() { +EpsilonGreedy::~EpsilonGreedy() { } template -A rl::policy::EpsilonGreedy::getAction( - const map& actionValues, const set& actionSet) { - if(_greediness==1.0F) return argMax(actionValues, actionSet); +spAction EpsilonGreedy::getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet) { + if (_greediness == 1.0F) { + return argMax(actionValues, actionSet); + } const rl::FLOAT& r = _distribution(_randomDevice); if (r > _greediness) { uniform_int_distribution indexDistribution(0, actionSet.size()); - typename set::const_iterator it(actionSet.begin()); + typename spActionSet::const_iterator it(actionSet.begin()); advance(it, indexDistribution(_randomDevice)); return (*it); } else { @@ -99,14 +107,17 @@ A rl::policy::EpsilonGreedy::getAction( } template -A rl::policy::EpsilonGreedy::getAction( - const map& actionValues, - const set& actionSet, const A& maxAction){ - if(_greediness == 1.0F) return maxAction; +spAction EpsilonGreedy::getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet, + const spAction& maxAction) { + if (_greediness == 1.0F) { + return maxAction; + } const rl::FLOAT& r = _distribution(_randomDevice); if (r > _greediness) { uniform_int_distribution indexDistribution(0, actionSet.size()); - typename set::const_iterator it(actionSet.begin()); + typename spActionSet::const_iterator it(actionSet.begin()); advance(it, indexDistribution(_randomDevice)); return (*it); } else { @@ -115,8 +126,9 @@ A rl::policy::EpsilonGreedy::getAction( } template -A rl::policy::EpsilonGreedy::argMax( - const map& actionValues, const set& actionSet) const { +spAction EpsilonGreedy::argMax( + const spActionValueMap& actionValues, + const spActionSet& actionSet) const { auto 
maxActionIter = actionSet.begin(); rl::FLOAT maxVal = actionValues.at(*maxActionIter); for (auto iter = actionSet.begin(); iter != actionSet.end(); ++iter) { @@ -130,18 +142,16 @@ A rl::policy::EpsilonGreedy::argMax( } template -void rl::policy::EpsilonGreedy::setGreediness( +void EpsilonGreedy::setGreediness( rl::FLOAT greediness) { this->_greediness = greediness; } template -rl::FLOAT rl::policy::EpsilonGreedy::getGreediness() const { +rl::FLOAT EpsilonGreedy::getGreediness() const { return _greediness; } -} // Policy -} // rl - -#endif /* EPSILONGREEDY_H */ +} // namespace policy +} // namespace rl diff --git a/include/policy/Policy.h b/include/policy/Policy.h index 8235243..12bdd33 100644 --- a/include/policy/Policy.h +++ b/include/policy/Policy.h @@ -5,14 +5,12 @@ * Created on June 6, 2014, 5:48 PM */ -#ifndef POLICY_H -#define POLICY_H - -#include "../declares.h" +#pragma once #include #include +#include "../declares.h" #include "../agent/StateAction.h" using namespace std; @@ -41,8 +39,10 @@ class Policy { * @return action given a mapping of actions and their value and a * set of actions. */ - virtual A getAction(const map& actionValues, - const set& actionSet) = 0; + + virtual spAction getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet) = 0; /** * Returns action given a mapping of actions and their value and a @@ -54,8 +54,9 @@ class Policy { * @return action given a mapping of actions and their value and a * set of actions. */ - virtual A getAction(const map& actionValues, - const set& actionSet, const A& maxAction) = 0; + virtual spAction getAction(const spActionValueMap& actionValues, + const spActionSet& actionSet, + const rl::spAction& maxAction) = 0; private: }; @@ -67,10 +68,7 @@ class Policy { * and action space, PolicySL is a typedef of Policy specifically for that * purpose. */ -typedef Policy PolicySL; - -} /* Policy */ -} /* rl */ - -#endif /* POLICY_H */ +typedef Policy PolicySL; +} // namespace policy +} // rl diff --git a/include/policy/Softmax.h b/include/policy/Softmax.h index f581055..d0e6f4a 100644 --- a/include/policy/Softmax.h +++ b/include/policy/Softmax.h @@ -50,10 +50,10 @@ class Softmax : public Policy { public: Softmax(rl::FLOAT temperature); - virtual A getAction(const map& actionValues, - const set& actionSet) override; - virtual A getAction(const map& actionValues, - const set& actionSet, const A& maxAction) override; + virtual spAction getAction(const spActionValueMap& actionValues, + const spActionSet& actionSet) override; + virtual spAction getAction(const spActionValueMap& actionValues, + const spActionSet& actionSet, const spAction& maxAction) override; private: std::random_device _randomDevice; std::uniform_real_distribution _distribution; @@ -67,22 +67,22 @@ typedef Softmax, vector > SoftmaxSL; template rl::policy::Softmax::Softmax(rl::FLOAT temperature) - : _distribution(0.0F, 1.0F) { - _temperature = temperature; + : _distribution(0.0F, 1.0F), + _temperature(temperature) { } template -A rl::policy::Softmax::getAction( - const map& actionValues, const set& actionSet) { +spAction rl::policy::Softmax::getAction( + const spActionValueMap& actionValues, const spActionSet& actionSet) { // Acquire E(i=1...n) e^(Q(i)/temp) rl::FLOAT sum = 0.0F; - for (const A& action : actionSet) { + for (const spAction& action : actionSet) { sum += exp(actionValues.at(action) / _temperature); } // Acquire probability for each action. 
- map actionProbabilityMap; - for (const A& action : actionSet) { + spActionValueMap actionProbabilityMap; + for (const spAction& action : actionSet) { rl::FLOAT probability = exp(actionValues.at(action) / _temperature) / sum; actionProbabilityMap[action] = probability; } @@ -110,8 +110,9 @@ A rl::policy::Softmax::getAction( } template -A rl::policy::Softmax::getAction(const map& actionValues, - const set& actionSet, const A& maxAction){ +spAction rl::policy::Softmax::getAction( + const spActionValueMap& actionValues, + const spActionSet& actionSet, const spAction& maxAction){ return getAction(actionValues, actionSet); } diff --git a/src/algorithm/gradient-descent/ReinforcementLearningGD.cpp b/src/algorithm/gradient-descent/ReinforcementLearningGD.cpp index 3f4ad12..c70d28b 100644 --- a/src/algorithm/gradient-descent/ReinforcementLearningGD.cpp +++ b/src/algorithm/gradient-descent/ReinforcementLearningGD.cpp @@ -9,26 +9,27 @@ namespace algorithm { ReinforcementLearningGD::ReinforcementLearningGD( TileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, - rl::FLOAT lambda, policy::Policy& policy) - : LearningAlgorithm(policy), + rl::FLOAT lambda, policy::Policy& policy) + : LearningAlgorithm(policy), _gradientDescent(tileCode, stepSize, discountRate, lambda) { } void ReinforcementLearningGD::_buildActionValues( - const set& actionSet, const STATE_CONT& nextState, - map& actionValueMap, - ACTION_CONT& action) { + const spActionSet& actionSet, + const spStateCont& nextState, + spActionValueMap& actionValueMap, + spActionCont& action) { _gradientDescent.buildActionValues(actionSet, nextState, actionValueMap, action); } void ReinforcementLearningGD::update( - const StateAction& currentStateAction, - const STATE_CONT& nextStateVector, const FLOAT reward, - const set& actionSet) { - map actionValueMap; - ACTION_CONT maxAction; + const StateAction& currentStateAction, + const spStateCont& nextStateVector, const FLOAT reward, + const spActionSet& actionSet) { + spActionValueMap actionValueMap; + spActionCont maxAction(new actionCont); _buildActionValues(actionSet, nextStateVector, actionValueMap, maxAction); - const ACTION_CONT& nextAction = this->_getLearningPolicyAction( + const spActionCont& nextAction = this->_getLearningPolicyAction( actionValueMap, actionSet, maxAction); _gradientDescent.updateWeights(currentStateAction.getState(), @@ -37,29 +38,29 @@ void ReinforcementLearningGD::update( reward); } -ACTION_CONT ReinforcementLearningGD::getAction( - const STATE_CONT& state, const set& actionSet) { - map actionValueMap; - ACTION_CONT maxAction; +spActionCont ReinforcementLearningGD::getAction( + const spStateCont& state, const spActionSet& actionSet) { + spActionValueMap actionValueMap; + spActionCont maxAction(new actionCont); _buildActionValues(actionSet, state, actionValueMap, maxAction); return this->_learningPolicy->getAction(actionValueMap, actionSet, maxAction); } FLOAT ReinforcementLearningGD::getStateActionValue( - const StateAction& stateAction) { + const StateAction& stateAction) { return _getStateActionValue(stateAction); } FLOAT ReinforcementLearningGD::_getStateActionValue( - const StateAction& stateAction) { - vector copy = stateAction.getState(); - copy.insert(copy.end(), stateAction.getAction().begin(), - stateAction.getAction().end()); + const StateAction& stateAction) { + auto copy = *(stateAction.getState()); + copy.insert(copy.end(), stateAction.getAction()->begin(), + stateAction.getAction()->end()); return _gradientDescent.getValueFromParameters(copy); } void 
ReinforcementLearningGD::reset() { - LearningAlgorithm::reset(); + LearningAlgorithm::reset(); _gradientDescent.resetEligibilityTraces(); } diff --git a/src/algorithm/gradient-descent/SarsaETGD.cpp b/src/algorithm/gradient-descent/SarsaETGD.cpp index b9e48e5..923d4c5 100644 --- a/src/algorithm/gradient-descent/SarsaETGD.cpp +++ b/src/algorithm/gradient-descent/SarsaETGD.cpp @@ -9,7 +9,7 @@ namespace algorithm { SarsaETGD::SarsaETGD( TileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, - rl::FLOAT lambda, policy::Policy, vector >& policy) + rl::FLOAT lambda, policy::Policy& policy) : ReinforcementLearningGD(tileCode, stepSize, discountRate, lambda, policy) { this->setLearningPolicy(policy); } diff --git a/src/coding/TileCodeCorrect.cpp b/src/coding/TileCodeCorrect.cpp index 6d372d2..d092f05 100644 --- a/src/coding/TileCodeCorrect.cpp +++ b/src/coding/TileCodeCorrect.cpp @@ -13,8 +13,8 @@ TileCodeCorrect::TileCodeCorrect(vector >& dimensionalInfos } FEATURE_VECTOR TileCodeCorrect::getFeatureVector( - const STATE_CONT& parameters) { - assert(this->getDimension() == parameters.size()); + const floatVector& parameters) { + assert(this->getDimension() == parameters.size()); FEATURE_VECTOR fv; fv.resize(this->_numTilings); @@ -23,7 +23,7 @@ FEATURE_VECTOR TileCodeCorrect::getFeatureVector( rl::INT hashedIndex = 0; rl::INT mult = 1; for (size_t j = 0; j < this->getDimension(); j++) { - hashedIndex += this->paramToGridValue(parameters[j], i, j) * mult; + hashedIndex += this->paramToGridValue(parameters.at(j), i, j) * mult; mult *= this->_dimensionalInfos[j].GetGridCountReal(); } diff --git a/src/coding/TileCodeMt1993764.cpp b/src/coding/TileCodeMt1993764.cpp index 9d8c332..056883d 100644 --- a/src/coding/TileCodeMt1993764.cpp +++ b/src/coding/TileCodeMt1993764.cpp @@ -22,13 +22,13 @@ TileCodeMt1993764::TileCodeMt1993764( } FEATURE_VECTOR TileCodeMt1993764::getFeatureVector( - const STATE_CONT& parameters) { + const floatVector& parameters) { vector tileComponents(this->getDimension() + 1); FEATURE_VECTOR fv; for (size_t i = 0; i < this->_numTilings; i++) { for (size_t j = 0; j < this->getDimension(); j++) { - tileComponents[j] = this->paramToGridValue(parameters[j], i, j); + tileComponents[j] = this->paramToGridValue(parameters.at(j), i, j); } // Add a unique number_tiling identifier. diff --git a/src/coding/TileCodeMurMur.cpp b/src/coding/TileCodeMurMur.cpp index 2510601..4f2b6da 100644 --- a/src/coding/TileCodeMurMur.cpp +++ b/src/coding/TileCodeMurMur.cpp @@ -22,14 +22,14 @@ TileCodeMurMur::TileCodeMurMur( } FEATURE_VECTOR TileCodeMurMur::getFeatureVector( - const STATE_CONT& parameters) { + const floatVector& parameters) { assert(this->getDimension() == parameters.size()); FEATURE_VECTOR fv; vector tileComponents(this->getDimension() + 1); for (rl::INT i = 0; i < this->_numTilings; i++) { for (size_t j = 0; j < this->getDimension(); j++) { - tileComponents[j] = this->paramToGridValue(parameters[j], i, j); + tileComponents[j] = this->paramToGridValue(parameters.at(j), i, j); } // Add a unique number_tiling identifier. 
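Editor's note (not part of the patch): the tile coders above now take a rl::floatVector and index it with bounds-checked .at(). A minimal usage sketch of that interface, assuming the installed umbrella header <rl> and the DimensionInfo/TileCodeCorrect constructors exercised by the tests later in this diff; the dimension ranges and tiling count here are illustrative only.

    #include <cassert>
    #include <vector>
    #include <rl>  // Assumed umbrella header installed by the top-level CMakeLists.

    int main() {
      // Two continuous dimensions, each split into 10 grid intervals.
      std::vector<rl::coding::DimensionInfo<rl::FLOAT>> dims = {
        rl::coding::DimensionInfo<rl::FLOAT>(-1.0F, 1.0F, 10),
        rl::coding::DimensionInfo<rl::FLOAT>(-1.0F, 1.0F, 10)
      };

      rl::coding::TileCodeCorrect tileCode(dims, 8);  // 8 tilings (offsets).

      // getFeatureVector returns one tile index per tiling; the switch to .at()
      // means a too-short parameter vector now throws std::out_of_range instead
      // of silently reading out of bounds.
      rl::FEATURE_VECTOR fv = tileCode.getFeatureVector(rl::floatVector({0.25F, -0.75F}));
      assert(fv.size() == 8);
      return 0;
    }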
diff --git a/src/coding/TileCodeSuperFastHash.cpp b/src/coding/TileCodeSuperFastHash.cpp index e433168..b0ec0d5 100644 --- a/src/coding/TileCodeSuperFastHash.cpp +++ b/src/coding/TileCodeSuperFastHash.cpp @@ -23,13 +23,13 @@ TileCodeSuperFastHash::TileCodeSuperFastHash(vector >& dime } FEATURE_VECTOR TileCodeSuperFastHash::getFeatureVector( - const STATE_CONT& parameters) { + const floatVector& parameters) { vector tileComponents(this->getDimension() + 1); FEATURE_VECTOR fv; for (size_t i = 0; i < this->_numTilings; i++) { for (size_t j = 0; j < this->getDimension(); j++) { - tileComponents[j] = this->paramToGridValue(parameters[j], i, j); + tileComponents[j] = this->paramToGridValue(parameters.at(j), i, j); } // Add a unique number_tiling identifier. diff --git a/src/coding/TileCodeUNH.cpp b/src/coding/TileCodeUNH.cpp index 15047d3..efd58f1 100644 --- a/src/coding/TileCodeUNH.cpp +++ b/src/coding/TileCodeUNH.cpp @@ -39,7 +39,7 @@ TileCodeUNH::TileCodeUNH( } FEATURE_VECTOR TileCodeUNH::getFeatureVector( - const STATE_CONT& parameters) { + const floatVector& parameters) { assert(this->getDimension() == parameters.size()); FEATURE_VECTOR fv; @@ -48,7 +48,7 @@ FEATURE_VECTOR TileCodeUNH::getFeatureVector( for (size_t i = 0; i < this->getDimension(); i++) { // Note to floor since casting to integer is not consistent // with negative number. Casting is always a number toward zero. - qStates[i] = floor(parameters[i] * _normalization[i]); + qStates[i] = floor(parameters.at(i) * _normalization[i]); } for (size_t i = 0; i < this->_numTilings; i++) { diff --git a/src/intrinsic/GradientDescent.cpp b/src/intrinsic/GradientDescent.cpp index fac8ae5..98f06c7 100644 --- a/src/intrinsic/GradientDescent.cpp +++ b/src/intrinsic/GradientDescent.cpp @@ -3,6 +3,7 @@ */ #include +#include #include "algorithm/gradient-descent/GradientDescent.h" @@ -59,7 +60,7 @@ size_t GradientDescent::getSize() const { } FLOAT GradientDescent::getValueFromParameters( - const vector& parameters) const { + const floatVector& parameters) const { FEATURE_VECTOR fv = std::move(_tileCode.getFeatureVector(parameters)); return getValueFromFeatureVector(fv); @@ -142,21 +143,23 @@ void GradientDescent::backUpWeights(FLOAT tdError) { } void GradientDescent::updateWeights( - const STATE_CONT& currentStateVector, - const ACTION_CONT& currentActionVector, - const STATE_CONT& nextStateVector, const FLOAT nextActionValue, + const spStateCont& currentStateVector, + const spActionCont& currentActionVector, + const spStateCont& nextStateVector, + const FLOAT nextActionValue, const FLOAT reward) { vector currentStateVectorCopy; - currentStateVectorCopy.reserve(currentStateVector.size() + - currentActionVector.size()); + currentStateVectorCopy.reserve(currentStateVector->size() + + currentActionVector->size()); currentStateVectorCopy.insert(currentStateVectorCopy.end(), - currentStateVector.begin(), - currentStateVector.end()); + currentStateVector->begin(), + currentStateVector->end()); currentStateVectorCopy.insert(currentStateVectorCopy.end(), - currentActionVector.begin(), - currentActionVector.end()); + currentActionVector->begin(), + currentActionVector->end()); - FEATURE_VECTOR currentStateFv = std::move(getFeatureVector(currentStateVectorCopy)); + FEATURE_VECTOR currentStateFv = + std::move(getFeatureVector(currentStateVectorCopy)); incrementEligibilityTraces(currentStateFv); FLOAT tdError = reward + _discountRate * nextActionValue @@ -167,31 +170,31 @@ void GradientDescent::updateWeights( decreaseEligibilityTraces(); } -FEATURE_VECTOR 
GradientDescent::getFeatureVector(const vector& parameters) const { +FEATURE_VECTOR GradientDescent::getFeatureVector(const floatVector& parameters) const { return _tileCode.getFeatureVector(parameters); } void GradientDescent::buildActionValues( - const set& actionSet, const vector& nextState, - map& actionVectorValueMap, ACTION_CONT& actions) const { - set::const_iterator maxActionIter = actionSet.begin(); + const spActionSet& actionSet, const spStateCont& nextState, + spActionValueMap& actionVectorValueMap, spActionCont& actions) const { + spActionSet::const_iterator maxActionIter = actionSet.begin(); // Build pc = array. vector pc; - pc.reserve(nextState.size() + (*maxActionIter).size()); - pc.insert(pc.end(), nextState.begin(), nextState.end()); - pc.insert(pc.end(), (*maxActionIter).begin(), (*maxActionIter).end()); + pc.reserve(nextState->size() + (*maxActionIter)->size()); + pc.insert(pc.end(), nextState->begin(), nextState->end()); + pc.insert(pc.end(), (*maxActionIter)->begin(), (*maxActionIter)->end()); FLOAT maxVal = getValueFromParameters(pc); actionVectorValueMap[*maxActionIter] = maxVal; - set::const_iterator iter = maxActionIter; + spActionSet::const_iterator iter = maxActionIter; iter++; for (; iter != actionSet.end(); ++iter) { vector paramCopy; - paramCopy.reserve(nextState.size() + (*iter).size()); - paramCopy.insert(paramCopy.end(), nextState.begin(), nextState.end()); - paramCopy.insert(paramCopy.end(), (*iter).begin(), (*iter).end()); + paramCopy.reserve(nextState->size() + (*iter)->size()); + paramCopy.insert(paramCopy.end(), nextState->begin(), nextState->end()); + paramCopy.insert(paramCopy.end(), (*iter)->begin(), (*iter)->end()); FLOAT value = getValueFromParameters(paramCopy); actionVectorValueMap[*iter] = value; @@ -224,10 +227,12 @@ void GradientDescent::resetEligibilityTraces() { } FLOAT GradientDescent::getMaxValue( - const map& actionValueMap) const { + const spActionValueMap& actionValueMap) const { // Get max action. FLOAT maxValue = actionValueMap.begin()->second; - for (auto iter = actionValueMap.begin(); iter != actionValueMap.end(); ++iter) { + for (auto iter = actionValueMap.begin(); + iter != actionValueMap.end(); + ++iter) { if (iter->second > maxValue) { maxValue = iter->second; } @@ -236,5 +241,5 @@ FLOAT GradientDescent::getMaxValue( return maxValue; } -} // Algorithm -} // rl +} // namespace Algorithm +} // namespace rl diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ce87e6b..85b2d37 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -9,9 +9,3 @@ file(GLOB SRC_TEST_FILES "src/**/*.cpp") add_executable(testRunner testRunner.cpp ${SRC_TEST_FILES}) target_link_libraries(testRunner rlTestHelper) - -add_custom_command( - TARGET testRunner - POST_BUILD - COMMAND ${PROJECT_BINARY_DIR}/test/testRunner -) diff --git a/test/helper/include/MountainCarEnvironment.h b/test/helper/include/MountainCarEnvironment.h index 870dde5..1fa78f1 100644 --- a/test/helper/include/MountainCarEnvironment.h +++ b/test/helper/include/MountainCarEnvironment.h @@ -20,15 +20,15 @@ using namespace std; namespace rl { -class MountainCarEnvironment : public Environment{ +class MountainCarEnvironment : public Environment{ public: - using SA = Environment::SA; + using SA = Environment::SA; public: - MountainCarEnvironment(Actuator& actuator, Sensor& sensor); + MountainCarEnvironment(Actuator& actuator, Sensor& sensor); // Overloaded methods. 
- virtual std::pair getNextStateAndReward(const SA& stateAction) override; + virtual rl::spStateAndReward getNextStateAndReward(const SA& stateAction) override; }; } /* namespace rl */ diff --git a/test/helper/include/RandomWalkEnvironment.h b/test/helper/include/RandomWalkEnvironment.h index 5950521..7535a2e 100644 --- a/test/helper/include/RandomWalkEnvironment.h +++ b/test/helper/include/RandomWalkEnvironment.h @@ -14,23 +14,23 @@ using std::map; -const rl::INT A(0), B(1), C(2), D(3), T(4); -const rl::INT L(0), R(1); +const rl::spState A(new rl::INT(0)), B(new rl::INT(1)), C(new rl::INT(2)), D(new rl::INT(3)), T(new rl::INT(4)); +const rl::spAction L(new rl::INT(0)), R(new rl::INT(1)); namespace rl { class RandomWalkEnvironment : public Environment{ public: enum State : int { A = 0, B, C, D, T }; - enum Action : int { L = 0, R = 1 }; + enum Action : int { L = 0, R }; public: RandomWalkEnvironment(Actuator& actuator, Sensor& sensor); - virtual std::pair getNextStateAndReward(const StateAction& stateAction) override; + virtual std::pair, FLOAT> getNextStateAndReward(const StateAction& stateAction) override; protected: - map, FLOAT> _env; + map, spState> _env; }; } /* namespace rl */ diff --git a/test/helper/include/SensorMountainCar.h b/test/helper/include/SensorMountainCar.h index d0d6d64..f7863f4 100644 --- a/test/helper/include/SensorMountainCar.h +++ b/test/helper/include/SensorMountainCar.h @@ -15,13 +15,14 @@ using namespace std; namespace rl { -class SensorMountainCar final : - public Sensor { -public: +namespace agent { +class SensorMountainCar final : + public Sensor { + public: SensorMountainCar(); - virtual bool isTerminalState(const STATE_CONT& stateData) const; + virtual bool isTerminalState(const spStateCont &stateData) const; }; - +} } /* namespace rl */ #endif /* SENSORMOUNTAINCAR_H_ */ diff --git a/test/helper/include/SensorRandomWalk.h b/test/helper/include/SensorRandomWalk.h index 953fa52..d5daca0 100644 --- a/test/helper/include/SensorRandomWalk.h +++ b/test/helper/include/SensorRandomWalk.h @@ -11,12 +11,12 @@ #include "rl" namespace rl { - -class SensorRandomWalk : public SensorDiscrete { +namespace agent { +class SensorRandomWalk : public SensorDiscrete { public: SensorRandomWalk(); }; - +} } /* namespace rl */ #endif /* SENSORRANDOMWALK_H_ */ diff --git a/test/helper/src/MountainCarEnvironment.cpp b/test/helper/src/MountainCarEnvironment.cpp index 654bc1f..8a5fb23 100644 --- a/test/helper/src/MountainCarEnvironment.cpp +++ b/test/helper/src/MountainCarEnvironment.cpp @@ -13,11 +13,11 @@ using namespace std; namespace rl{ rl::MountainCarEnvironment::MountainCarEnvironment( - Actuator& actuator, Sensor& sensor) : - rl::agent::Environment(actuator, sensor) { + Actuator& actuator, Sensor& sensor) : + rl::agent::Environment(actuator, sensor) { } -std::pair MountainCarEnvironment::getNextStateAndReward( +rl::spStateAndReward MountainCarEnvironment::getNextStateAndReward( const MountainCarEnvironment::SA& stateAction) { auto act = stateAction.getAction(); @@ -27,7 +27,7 @@ std::pair MountainCarEnvironment::getNextStateAndReward( }*/ auto nextState = stateAction.getState(); - rl::INT copyAct = act[0] - 1; + rl::INT copyAct = act->at(0) - 1; rl::FLOAT nextReward = -1; if (copyAct == 0) { nextReward = -1; @@ -35,20 +35,20 @@ std::pair MountainCarEnvironment::getNextStateAndReward( nextReward = -2; } - nextState[VEL] += - (0.001F * copyAct - 0.0025F * cos(3.0F * nextState[POS])); + nextState->at(VEL) += + (0.001F * copyAct - 0.0025F * cos(3.0F * nextState->at(POS))); - if 
(nextState[VEL] < -0.07F) - nextState[VEL] = -0.07F; - else if (nextState[VEL] >= 0.07F) - nextState[VEL] = 0.06999999F; - nextState[POS] += nextState[VEL]; - if (nextState[POS] >= 0.5F) { - nextState[POS] = 0.5; + if (nextState->at(VEL) < -0.07F) + nextState->at(VEL) = -0.07F; + else if (nextState->at(VEL) >= 0.07F) + nextState->at(VEL) = 0.06999999F; + nextState->at(POS) += nextState->at(VEL); + if (nextState->at(POS) >= 0.5F) { + nextState->at(POS) = 0.5; nextReward = 0; - } else if (nextState[POS] < -1.2F) { - nextState[POS] = -1.2F; - nextState[VEL] = 0.0F; + } else if (nextState->at(POS) < -1.2F) { + nextState->at(POS) = -1.2F; + nextState->at(VEL) = 0.0F; } diff --git a/test/helper/src/RandomWalkEnvironment.cpp b/test/helper/src/RandomWalkEnvironment.cpp index 33f95e5..eb42c1c 100644 --- a/test/helper/src/RandomWalkEnvironment.cpp +++ b/test/helper/src/RandomWalkEnvironment.cpp @@ -11,21 +11,21 @@ namespace rl { RandomWalkEnvironment::RandomWalkEnvironment( Actuator& actuator, Sensor& sensor) : Environment(actuator, sensor) { - _env[rl::agent::StateAction(A, L)] = T; - _env[rl::agent::StateAction(A, R)] = B; - _env[rl::agent::StateAction(B, L)] = A; - _env[rl::agent::StateAction(B, R)] = C; - _env[rl::agent::StateAction(C, L)] = B; - _env[rl::agent::StateAction(C, R)] = D; - _env[rl::agent::StateAction(D, L)] = C; - _env[rl::agent::StateAction(D, R)] = T; + _env[rl::agent::StateAction(::A, ::L)] = ::T; + _env[rl::agent::StateAction(::A, ::R)] = ::B; + _env[rl::agent::StateAction(::B, ::L)] = ::A; + _env[rl::agent::StateAction(::B, ::R)] = ::C; + _env[rl::agent::StateAction(::C, ::L)] = ::B; + _env[rl::agent::StateAction(::C, ::R)] = ::D; + _env[rl::agent::StateAction(::D, ::L)] = ::C; + _env[rl::agent::StateAction(::D, ::R)] = ::T; } -std::pair RandomWalkEnvironment::getNextStateAndReward(const StateAction& stateAction) { - INT nextState = this->_env[stateAction]; +std::pair, FLOAT> RandomWalkEnvironment::getNextStateAndReward(const StateAction& stateAction) { + spState nextState = this->_env[stateAction]; FLOAT nextReward = -1.0F; - if (nextState == T) { + if (*nextState == *(::T)) { nextReward = 0.0F; } diff --git a/test/helper/src/SensorMountainCar.cpp b/test/helper/src/SensorMountainCar.cpp index 496ec33..64ebc9c 100644 --- a/test/helper/src/SensorMountainCar.cpp +++ b/test/helper/src/SensorMountainCar.cpp @@ -12,20 +12,20 @@ using namespace std; namespace rl { - +namespace agent { SensorMountainCar::SensorMountainCar() : - Sensor(STATE_CONT(2, 0)) { + Sensor(rl::spStateCont(new rl::stateCont(2, 0))) { } bool SensorMountainCar::isTerminalState( - const STATE_CONT& stateData) const { - if (std::abs(stateData[0] - 0.50F) <= 0.01F) { + const spStateCont &stateData) const { + if (std::abs(stateData->at(0) - 0.50F) <= 0.01F) { return true; } return false; } - +} } /* namespace rl */ diff --git a/test/helper/src/SensorRandomWalk.cpp b/test/helper/src/SensorRandomWalk.cpp index 83b7436..b807c40 100644 --- a/test/helper/src/SensorRandomWalk.cpp +++ b/test/helper/src/SensorRandomWalk.cpp @@ -6,9 +6,11 @@ #include "../include/RandomWalkEnvironment.h" namespace rl { +namespace agent { SensorRandomWalk::SensorRandomWalk() : - SensorDiscrete(RandomWalkEnvironment::State::B){ + SensorDiscrete(spState(new INT(RandomWalkEnvironment::State::B))) { } } +} diff --git a/test/src/agent/ActionSet_test.cpp b/test/src/agent/ActionSet_test.cpp index 1a107b4..a36beae 100644 --- a/test/src/agent/ActionSet_test.cpp +++ b/test/src/agent/ActionSet_test.cpp @@ -23,7 +23,12 @@ SCENARIO("Agent have a storage 
for actions.", } GIVEN("ActionSet instance with some actions initialize.") { - ActionSet as({ 1, 2, 3, 4 }); + rl::spAction action1(new rl::INT(1)); + rl::spAction action2(new rl::INT(2)); + rl::spAction action3(new rl::INT(3)); + rl::spAction action4(new rl::INT(4)); + + ActionSet as(rl::spActionSet({ action1, action2, action3, action4 })); WHEN ("I first access action.") { auto actions = as.getActionSet(); diff --git a/test/src/agent/ActuatorBase_test.cpp b/test/src/agent/ActuatorBase_test.cpp index 97f7cd2..d1452bd 100644 --- a/test/src/agent/ActuatorBase_test.cpp +++ b/test/src/agent/ActuatorBase_test.cpp @@ -13,14 +13,9 @@ using namespace std; SCENARIO("Agent have a storage for actions and actuator.", "[rl::agent::Actuator]") { GIVEN("Actuator") { - rl::agent::Actuator actuator( - { - rl::RandomWalkEnvironment::Action::L, - rl::RandomWalkEnvironment::Action::R - } - ); + rl::agent::Actuator actuator(rl::spActionSet({ L, R })); rl::SensorRandomWalk srw; - srw.addTerminalState(rl::RandomWalkEnvironment::State::T); + srw.addTerminalState(T); rl::RandomWalkEnvironment rwe(actuator, srw); diff --git a/test/src/agent/AgentSupervised_test.cpp b/test/src/agent/AgentSupervised_test.cpp index da6423e..9d3d418 100644 --- a/test/src/agent/AgentSupervised_test.cpp +++ b/test/src/agent/AgentSupervised_test.cpp @@ -11,18 +11,24 @@ using namespace std; SCENARIO("Supervised agent develop an accurate model of the environment.", "[rl::agent::AgentSupervised]") { GIVEN("A binary environment in which 1 is good and 0 is bad.") { + rl::spState state0(new int(0)); + rl::spState state1(new int(1)); + + rl::spAction action0(new int(0)); + rl::spAction action1(new int(1)); + rl::policy::EpsilonGreedy policy(1.0F); rl::algorithm::Sarsa sarsaAlgorithm(0.1F, 0.9F, policy); - auto actionSet = rl::agent::ActionSet({0, 1}); + auto actionSet = rl::agent::ActionSet({action0, action1}); rl::agent::AgentSupervised supevisedAgent(actionSet, sarsaAlgorithm); WHEN ("When I train 1 to be good and 0 to be bad") { - supevisedAgent.train(1, 1, 1000, 1); // We don't transition anywhere. It's just being in state 1 is good. - supevisedAgent.train(0, 0, -1000, 0); // Same deal. + supevisedAgent.train(state1, action1, 1000, state1); // We don't transition anywhere. It's just being in state 1 is good. + supevisedAgent.train(state0, action0, -1000, state0); // Same deal. 
THEN ("Agent should know that 1 should be good and 0 should be bad") { - auto value1 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(1, 1)); - auto value0 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(0, 0)); + auto value1 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(state1, action1)); + auto value0 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(state0, action0)); REQUIRE(value1 > value0); } diff --git a/test/src/agent/Agent_test.cpp b/test/src/agent/Agent_test.cpp index 5837928..fb7d665 100644 --- a/test/src/agent/Agent_test.cpp +++ b/test/src/agent/Agent_test.cpp @@ -16,14 +16,9 @@ using namespace std; SCENARIO("Agent interacts with the environment as expected.", "[rl::agent::Agent]") { GIVEN("A binary environment in which 1 is good and 0 is bad.") { - rl::agent::Actuator arw( - { - rl::RandomWalkEnvironment::Action::L, - rl::RandomWalkEnvironment::Action::R - } - ); + rl::agent::Actuator arw(rl::spActionSet({ L, R })); rl::SensorRandomWalk srw; - srw.addTerminalState(rl::RandomWalkEnvironment::State::T); + srw.addTerminalState(T); rl::RandomWalkEnvironment rwe(arw, srw); @@ -33,10 +28,10 @@ SCENARIO("Agent interacts with the environment as expected.", rl::agent::Agent agent(rwe, dynaQAlgorithm); WHEN ("When I move left.") { - REQUIRE(agent.getLastObservedState() == rl::RandomWalkEnvironment::State::B); - agent.applyAction(rl::RandomWalkEnvironment::Action::L); + REQUIRE(*(agent.getLastObservedState()) == *B); + agent.applyAction(L); THEN ("I move from B to A.") { - REQUIRE(agent.getLastObservedState() == rl::RandomWalkEnvironment::State::A); + REQUIRE(*(agent.getLastObservedState()) == *A); } } } diff --git a/test/src/agent/Environment_test.cpp b/test/src/agent/Environment_test.cpp index 4e1d7b6..a26295b 100644 --- a/test/src/agent/Environment_test.cpp +++ b/test/src/agent/Environment_test.cpp @@ -7,18 +7,24 @@ using namespace std; SCENARIO("Environment represents the thing agent interacts with.", "[rl::agent::AgentSupervised]") { GIVEN("A binary environment in which 1 is good and 0 is bad.") { + rl::spState state0(new int(0)); + rl::spState state1(new int(1)); + + rl::spAction action0(new int(0)); + rl::spAction action1(new int(1)); + rl::policy::EpsilonGreedy policy(1.0F); rl::algorithm::Sarsa sarsaAlgorithm(0.1F, 0.9F, policy); - auto actionSet = rl::agent::ActionSet({0, 1}); + auto actionSet = rl::agent::ActionSet(rl::spActionSet({action0, action1})); rl::agent::AgentSupervised supevisedAgent(actionSet, sarsaAlgorithm); WHEN ("When I train 1 to be good and 0 to be bad") { - supevisedAgent.train(1, 1, 1000, 1); // We don't transition anywhere. It's just being in state 1 is good. - supevisedAgent.train(0, 0, -1000, 0); // Same deal. + supevisedAgent.train(state1, action1, 1000, state1); // We don't transition anywhere. It's just being in state 1 is good. + supevisedAgent.train(state0, action0, -1000, state0); // Same deal. 
THEN ("Agent should know that 1 should be good and 0 should be bad") { - auto value1 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(1, 1)); - auto value0 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(0, 0)); + auto value1 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(state1, action1)); + auto value0 = sarsaAlgorithm.getStateActionValue(rl::agent::StateAction(state0, action0)); REQUIRE(value1 > value0); } diff --git a/test/src/agent/Sensor_test.cpp b/test/src/agent/Sensor_test.cpp index b281afd..4480160 100644 --- a/test/src/agent/Sensor_test.cpp +++ b/test/src/agent/Sensor_test.cpp @@ -16,14 +16,9 @@ using namespace std; SCENARIO("Sensor for representing the states in the environment.", "[rl::agent::Sensor]") { GIVEN("Sensor instance") { - rl::agent::Actuator arw( - { - rl::RandomWalkEnvironment::Action::L, - rl::RandomWalkEnvironment::Action::R - } - ); + rl::agent::Actuator arw({ L, R }); rl::SensorRandomWalk srw; - srw.addTerminalState(rl::RandomWalkEnvironment::State::T); + srw.addTerminalState(T); rl::RandomWalkEnvironment rwe(arw, srw); @@ -33,10 +28,10 @@ SCENARIO("Sensor for representing the states in the environment.", rl::agent::Agent agent(rwe, dynaQAlgorithm); WHEN ("When I move left.") { - REQUIRE(srw.getLastObservedState() == rl::RandomWalkEnvironment::State::B); - agent.applyAction(rl::RandomWalkEnvironment::Action::L); + REQUIRE(*(srw.getLastObservedState()) == *B); + agent.applyAction(L); THEN ("I move from B to A.") { - REQUIRE(srw.getLastObservedState() == rl::RandomWalkEnvironment::State::A); + REQUIRE(*(srw.getLastObservedState()) == *A); } } } diff --git a/test/src/agent/StateActionPairContainer_test.cpp b/test/src/agent/StateActionPairContainer_test.cpp index 2bcdea5..c191d1a 100644 --- a/test/src/agent/StateActionPairContainer_test.cpp +++ b/test/src/agent/StateActionPairContainer_test.cpp @@ -12,11 +12,23 @@ using namespace std; SCENARIO("StateActionPairContainer a storage for state-action pair.", "[rl::agent::StateActionPairContainer]") { + rl::spState state0(new int(0)); + rl::spState state1(new int(1)); + rl::spState state2(new int(2)); + rl::spState state3(new int(3)); + rl::spState state4(new int(4)); + + rl::spAction action1(new int(1)); + rl::spAction action2(new int(2)); + rl::spAction action3(new int(3)); + rl::spAction action4(new int(4)); + rl::spAction action5(new int(5)); + GIVEN("Empty rl::agent::StateActionPairContainer instance") { rl::agent::StateActionPairContainer container; WHEN ("Adding a state-action.") { - container.addStateAction(rl::agent::StateAction(0, 1), 1); + container.addStateAction(rl::agent::StateAction(state0, action1), 1); THEN ("Size should be 1p") { REQUIRE(container.getMap().size() == 1); } @@ -27,9 +39,11 @@ SCENARIO("StateActionPairContainer a storage for state-action pair.", rl::agent::StateActionPairContainer container; WHEN ("Calling addSate with 5 actions.") { - container.addState(5, -1, std::set({ - 1, 2, 3, 4, 5 - })); + container.addState(state0, -1, rl::spActionSet( + { + action1, action2, action3, action4, action5 + } + )); THEN ("Changes the count to 5.") { REQUIRE(container.getMap().size() == 5); } @@ -39,12 +53,12 @@ SCENARIO("StateActionPairContainer a storage for state-action pair.", GIVEN("Empty rl::agent::StateActionPairContainer") { rl::agent::StateActionPairContainer container; - WHEN ("Calling addAction with 5 actions.") { - container.addAction(5, -1, std::set( + WHEN ("Calling addAction with 5 states.") { + container.addAction(action5, -1, rl::spStateSet( { - 1, 2, 
3, 4, 5 - }) - ); + state0, state1, state2, state3, state4 + } + )); THEN ("Changes the count to 5.") { REQUIRE(container.getMap().size() == 5); } @@ -53,13 +67,13 @@ SCENARIO("StateActionPairContainer a storage for state-action pair.", GIVEN("A non-empty rl::agent::StateActionPairContainer") { rl::agent::StateActionPairContainer container; - container.addStateAction(rl::agent::StateAction(0, 1), 1); + container.addStateAction(rl::agent::StateAction(state0, action1), 1); WHEN ("Calling setStateActionValue.") { - REQUIRE(container.getStateActionValue(rl::agent::StateAction(0, 1)) == 1); - container.setStateActionValue(rl::agent::StateAction(0, 1), 2); + REQUIRE(container.getStateActionValue(rl::agent::StateAction(state0, action1)) == 1); + container.setStateActionValue(rl::agent::StateAction(state0, action1), 2); THEN ("Changes the value of an existing state action.") { - REQUIRE(container.getStateActionValue(rl::agent::StateAction(0, 1)) == 2); + REQUIRE(container.getStateActionValue(rl::agent::StateAction(state0, action1)) == 2); } } } diff --git a/test/src/agent/StateActionPairValueComparison_test.cpp b/test/src/agent/StateActionPairValueComparison_test.cpp index 5a7bd99..f3c9787 100644 --- a/test/src/agent/StateActionPairValueComparison_test.cpp +++ b/test/src/agent/StateActionPairValueComparison_test.cpp @@ -14,11 +14,16 @@ using namespace std; SCENARIO("rl::agent::StateActionPairValueComparison can compaire two state action pair.", "[rl::agent::StateActionPairValueComparison]") { GIVEN("rl::agent::StateActionPairValueComparison") { + auto B = rl::spState(new int(rl::RandomWalkEnvironment::State::B)); + auto L = rl::spAction(new int(rl::RandomWalkEnvironment::Action::L)); + + auto A = rl::spState(new int(rl::RandomWalkEnvironment::State::A)); + rl::agent::StateActionPairValueComparison sapvc; auto sav1 = std::pair, double>( - rl::agent::StateAction(rl::RandomWalkEnvironment::State::B, rl::RandomWalkEnvironment::Action::L), 1.0F); + rl::agent::StateAction(B, L), 1.0F); auto sav2 = std::pair, double>( - rl::agent::StateAction(rl::RandomWalkEnvironment::State::A, rl::RandomWalkEnvironment::Action::L), 2.0F); + rl::agent::StateAction(A, L), 2.0F); WHEN ("I compare sa1 and sa2 using rl::agent::StateActionPairValueComparison.") { THEN ("It should return sa1 > sa2.") { diff --git a/test/src/agent/StateActionTransition_test.cpp b/test/src/agent/StateActionTransition_test.cpp index 3c69d83..54082af 100644 --- a/test/src/agent/StateActionTransition_test.cpp +++ b/test/src/agent/StateActionTransition_test.cpp @@ -26,20 +26,18 @@ SCENARIO("StateActionTransition have the ability to represent offline model for } GIVEN("A set of states") { - rl::INT state01(1); - rl::INT state02(2); - rl::INT state03(3); + rl::spState state01(new rl::INT(1)); + rl::spState state02(new rl::INT(2)); + rl::spState state03(new rl::INT(3)); WHEN("Update all of them. One more than the other.") { sat.update(state01, 10); REQUIRE(sat.getSize() == 1); REQUIRE(sat.getNextState() == state01); // 1 states stored. - rl::INT state02(2); sat.update(state02, 10); sat.update(state02, 10); REQUIRE(sat.getSize() == 2); // 2 states stored. 
- rl::INT state03(3); sat.update(state03, 10); sat.update(state03, 10); sat.update(state03, 10); @@ -51,7 +49,7 @@ SCENARIO("StateActionTransition have the ability to represent offline model for state03OccurenceCount = 0; for (rl::UINT i = 0; i < 1000; i++) { - const rl::INT& state = sat.getNextState(); + rl::spState state = sat.getNextState(); if (state == state01) { state01OccurenceCount++; } else if (state == state02) { @@ -78,7 +76,7 @@ SCENARIO("StateActionTransition have the ability to represent offline model for bool exceptionCalled = false; try { - sat.getReward(rl::INT(23)); + sat.getReward(rl::spState(new rl::INT(23))); } catch (StateActionTransitionException& exception) { //cout << exception.what() << endl; exceptionCalled = true; @@ -101,11 +99,13 @@ SCENARIO("StateActionTransition have the ability to represent offline model for } WHEN("I try to access a reward for a state that was not know to this StateActionTransition") { - sat.update(10, 100); + rl::spState state10(new rl::INT(10)); + rl::spState state23(new rl::INT(23)); + sat.update(state10, 100); THEN ("I get an exception.") { bool exceptionCalled = false; try { - sat.getReward(23); + sat.getReward(state23); } catch (StateActionTransitionException& exception) { //cout << exception.what() << endl; exceptionCalled = true; diff --git a/test/src/agent/StateAction_test.cpp b/test/src/agent/StateAction_test.cpp index 8cca7c7..d788c8c 100644 --- a/test/src/agent/StateAction_test.cpp +++ b/test/src/agent/StateAction_test.cpp @@ -7,20 +7,26 @@ #include "../../lib/catch.hpp" SCENARIO("StateAction data type","[rl::agent::StateAction") { + auto state01 = rl::spState(new rl::INT(10)); + auto action01 = rl::spAction(new rl::INT(10)); + + auto state02 = rl::spState(new rl::INT(20)); + auto action02 = rl::spAction(new rl::INT(20)); + GIVEN("StateAction instance") { - rl::agent::StateAction stateAction(10, 10); + rl::agent::StateAction stateAction(state01, action01); WHEN("Accessor method are invoked") { THEN("Retrieves appropriate value") { - REQUIRE(stateAction.getState() == 10); - REQUIRE(stateAction.getAction() == 10); + REQUIRE(stateAction.getState() == state01); + REQUIRE(stateAction.getAction() == action01); } } } GIVEN("Multiple StateAction instance") { - rl::agent::StateAction stateAction01(10, 10); - rl::agent::StateAction stateAction02(10, 10); - rl::agent::StateAction stateAction03(20, 10); + rl::agent::StateAction stateAction01(state01, action01); + rl::agent::StateAction stateAction02(state01, action01); + rl::agent::StateAction stateAction03(state02, action01); WHEN("Equality operator is invoked") { THEN("Returns true.") { REQUIRE(stateAction01 == stateAction02); @@ -31,10 +37,10 @@ SCENARIO("StateAction data type","[rl::agent::StateAction") { } GIVEN("Multiple StateAction instance") { - rl::agent::StateAction stateAction01(10, 10); - rl::agent::StateAction stateAction02(10, 20); - rl::agent::StateAction stateAction03(10, 20); - rl::agent::StateAction stateAction04(20, 10); + rl::agent::StateAction stateAction01(state01, action01); + rl::agent::StateAction stateAction02(state01, action02); + rl::agent::StateAction stateAction03(state01, action02); + rl::agent::StateAction stateAction04(state02, action01); WHEN("Less operator is invoked") { THEN("Compares lexically .") { REQUIRE(stateAction01 != stateAction02); diff --git a/test/src/algorithm/DynaQETWatkins_test.cpp b/test/src/algorithm/DynaQETWatkins_test.cpp index cb8b412..e6263c8 100644 --- a/test/src/algorithm/DynaQETWatkins_test.cpp +++ 
b/test/src/algorithm/DynaQETWatkins_test.cpp @@ -19,7 +19,7 @@ using namespace std; SCENARIO("DynaQTETWatkins converge to a solution", "[rl::DynaQTETWatkins]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; // Setup sensor. srw.addTerminalState(T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. diff --git a/test/src/algorithm/DynaQPrioritizedSweeping_test.cpp b/test/src/algorithm/DynaQPrioritizedSweeping_test.cpp index c2f5d95..117f5d2 100644 --- a/test/src/algorithm/DynaQPrioritizedSweeping_test.cpp +++ b/test/src/algorithm/DynaQPrioritizedSweeping_test.cpp @@ -19,7 +19,7 @@ using namespace std; SCENARIO("DynaQPrioritizedSweeping converge to a solution", "[rl::DynaQPrioritizedSweeping]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; // Setup sensor. srw.addTerminalState(T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. diff --git a/test/src/algorithm/DynaQ_test.cpp b/test/src/algorithm/DynaQ_test.cpp index a82c536..729fe0f 100644 --- a/test/src/algorithm/DynaQ_test.cpp +++ b/test/src/algorithm/DynaQ_test.cpp @@ -17,14 +17,9 @@ using namespace std; SCENARIO("DynaQ converge to a solution", "[rl::DynaQ]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw( - { - rl::RandomWalkEnvironment::Action::L, - rl::RandomWalkEnvironment::Action::R - } - ); + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; - srw.addTerminalState(rl::RandomWalkEnvironment::State::T); + srw.addTerminalState(T); rl::RandomWalkEnvironment rwe(arw, srw); diff --git a/test/src/algorithm/QLearningETWatkins_test.cpp b/test/src/algorithm/QLearningETWatkins_test.cpp index a85d21a..c53156c 100644 --- a/test/src/algorithm/QLearningETWatkins_test.cpp +++ b/test/src/algorithm/QLearningETWatkins_test.cpp @@ -19,7 +19,7 @@ using namespace std; SCENARIO("QLearningETWatkins converge to a solution", "[rl::QLearningETWatkins]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; // Setup sensor. srw.addTerminalState(T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. diff --git a/test/src/algorithm/QLearning_test.cpp b/test/src/algorithm/QLearning_test.cpp index 65a196d..3bf7777 100644 --- a/test/src/algorithm/QLearning_test.cpp +++ b/test/src/algorithm/QLearning_test.cpp @@ -19,7 +19,7 @@ using namespace std; SCENARIO("QLearning converge to a solution", "[rl::QLearning]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; // Setup sensor. srw.addTerminalState(T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. 
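Editor's note (not part of the patch): the remaining algorithm tests below all build their actuators from rl::spActionSet and compare observed states by dereferencing, following the shared-pointer convention added in declares.h. A minimal sketch of that convention, assuming the installed umbrella header <rl>; the names left/right/rightCopy are hypothetical and used only for illustration.

    #include <cassert>
    #include <rl>  // Assumed umbrella header installed by the top-level CMakeLists.

    int main() {
      // Actions live behind shared_ptr; containers order and deduplicate them by
      // the dereferenced value (spActionComp), not by pointer address.
      rl::spAction<rl::INT> left(new rl::INT(0));
      rl::spAction<rl::INT> right(new rl::INT(1));
      rl::spActionSet<rl::INT> actions({left, right});

      // A distinct pointer holding an equal value is treated as the same action.
      rl::spAction<rl::INT> rightCopy(new rl::INT(1));
      actions.insert(rightCopy);
      assert(actions.size() == 2);

      // States come back as shared_ptrs too, so the tests compare pointees,
      // e.g. REQUIRE(*(agent.getLastObservedState()) == *B).
      assert(*left < *right);
      return 0;
    }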
diff --git a/test/src/algorithm/SarsaET_test.cpp b/test/src/algorithm/SarsaET_test.cpp index 6e51d6a..04526f6 100644 --- a/test/src/algorithm/SarsaET_test.cpp +++ b/test/src/algorithm/SarsaET_test.cpp @@ -19,7 +19,7 @@ using namespace std; SCENARIO("SarsaET converge to a solution", "[rl::SarsaET]") { GIVEN("A random walk environment") { - rl::agent::Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. rl::SensorRandomWalk srw; // Setup sensor. srw.addTerminalState(T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. diff --git a/test/src/algorithm/Sarsa_test.cpp b/test/src/algorithm/Sarsa_test.cpp index e471813..f015e93 100644 --- a/test/src/algorithm/Sarsa_test.cpp +++ b/test/src/algorithm/Sarsa_test.cpp @@ -20,9 +20,9 @@ using namespace std; SCENARIO("Sarsa converge to a solution", "[rl::Sarsa]") { GIVEN("A random walk environment") { - Actuator arw({L, R}); // Setup actuator with actions. + rl::agent::Actuator arw(rl::spActionSet({ L, R })); // Setup actuator with actions. SensorRandomWalk srw; // Setup sensor. - srw.addTerminalState(T); // Setup terminal state. + srw.addTerminalState(::T); // Setup terminal state. rl::RandomWalkEnvironment rwEnv(arw, srw); // Setup environment. policy::EpsilonGreedy policy(1.0F); diff --git a/test/src/algorithm/gradient-descent/QLearningETGD_test.cpp b/test/src/algorithm/gradient-descent/QLearningETGD_test.cpp index 5fbddc7..53dee82 100644 --- a/test/src/algorithm/gradient-descent/QLearningETGD_test.cpp +++ b/test/src/algorithm/gradient-descent/QLearningETGD_test.cpp @@ -24,12 +24,14 @@ SCENARIO("Q-learning Eligibility Traces and Gradient Descent converge to a solut "[rl::QLearingETGD]") { GIVEN("A Mountain Car environment") { // Actions. - rl::agent::Actuator amc( - { - {0}, // Reverse. - {1}, // Neutral. - {2} // Forward. - } + rl::agent::Actuator amc( + rl::spActionSet( + { + {0}, // Reverse. + {1}, // Neutral. + {2} // Forward. + } + ) ); rl::SensorMountainCar smc; // Setup sensor. rl::MountainCarEnvironment mce(amc, smc); // Setup environment. diff --git a/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp b/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp index b618458..ed692a5 100644 --- a/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp +++ b/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp @@ -24,12 +24,14 @@ SCENARIO("Sarsa Eligibility Traces and Gradient Descent converge to a solution", "[rl::SarsaETGD]") { GIVEN("A Mountain Car environment") { // Actions. - rl::agent::Actuator amc( - { - {0}, // Reverse. - {1}, // Neutral. - {2} // Forward. - } + rl::agent::Actuator amc( + rl::spActionSet( + { + {0}, // Reverse. + {1}, // Neutral. + {2} // Forward. + } + ) ); rl::SensorMountainCar smc; // Setup sensor. rl::MountainCarEnvironment mce(amc, smc); // Setup environment. @@ -38,11 +40,11 @@ SCENARIO("Sarsa Eligibility Traces and Gradient Descent converge to a solution", // Setup tile coding. vector > dimensionalInfoVector = { - rl::coding::DimensionInfo(-1.2F, 0.5F, 10), // Velocity. - rl::coding::DimensionInfo(-0.07F, 0.07F, 10), // Position. + rl::coding::DimensionInfo(-1.2F, 0.5F, 20), // Velocity. + rl::coding::DimensionInfo(-0.07F, 0.07F, 20), // Position. rl::coding::DimensionInfo(0.0F, 2.0F, 3, 0.0F), // Action dimension. }; - rl::coding::TileCodeCorrect tileCode(dimensionalInfoVector, 10); // Setup tile coding with 10 offsets.
+ rl::coding::TileCodeCorrect tileCode(dimensionalInfoVector, 20); // Setup tile coding with 20 offsets. rl::algorithm::SarsaETGD sarsa(tileCode, 0.1F, 1.0F, 0.9F, policy); rl::agent::AgentSL agent(mce, sarsa); diff --git a/test/src/coding/TileCode_test.cpp b/test/src/coding/TileCode_test.cpp index 67d580f..5f99fa3 100644 --- a/test/src/coding/TileCode_test.cpp +++ b/test/src/coding/TileCode_test.cpp @@ -48,7 +48,7 @@ SCENARIO("Tile code retrieves the correct feature vector", WHEN ("TileCode::getFeatureVector is called for (-0.4, -0.4).") { THEN ("Return (0, 16, 32, 48)") { - auto fv = tileCode.getFeatureVector(STATE_CONT({-0.4, -0.4})); + auto fv = tileCode.getFeatureVector(stateCont({-0.4, -0.4})); decltype(fv) result { 0, 16, 32, 48 }; REQUIRE(fv == result); } diff --git a/test/src/policy/Softmax_test.cpp b/test/src/policy/Softmax_test.cpp index 3609171..e157f15 100644 --- a/test/src/policy/Softmax_test.cpp +++ b/test/src/policy/Softmax_test.cpp @@ -18,16 +18,16 @@ SCENARIO("Softmax action selection probability increases as the value of the act "[rl::Policy::Softmax]") { policy::Softmax policy(0.9F); - rl::INT dummyState(1); - rl::INT action01(1); - rl::INT action02(2); - rl::INT action03(3); - rl::INT action04(4); - set actionSet; - actionSet.insert(action01); - actionSet.insert(action02); - actionSet.insert(action03); - actionSet.insert(action04); + rl::spState dummyState(new rl::INT(1)); + rl::spAction action01(new rl::INT(1)); + rl::spAction action02(new rl::INT(2)); + rl::spAction action03(new rl::INT(3)); + rl::spAction action04(new rl::INT(4)); + spActionSet actionSet( + { + action01, action02, action03, action04 + } + ); using StateActionII = StateAction; @@ -38,22 +38,22 @@ SCENARIO("Softmax action selection probability increases as the value of the act stateActionPairValueMap[StateActionII(dummyState, action04)] = 2.2F; GIVEN("The following action-value map") { - map actionValueMap; - for (const rl::INT &action : actionSet) { + spActionValueMap actionValueMap; + for (auto &action : actionSet) { actionValueMap[action] = stateActionPairValueMap.at(StateActionII(dummyState, action)); } rl::UINT action01Ctr(0), action02Ctr(0), action03Ctr(0), action04Ctr(0); for (rl::INT i = 0; i < 1000; i++) { - rl::INT action = policy.getAction(actionValueMap, actionSet); + auto action = policy.getAction(actionValueMap, actionSet); - if (action == action01) { + if (*action == *action01) { action01Ctr++; - } else if (action == action02) { + } else if (*action == *action02) { action02Ctr++; - } else if (action == action03) { + } else if (*action == *action03) { action03Ctr++; - } else if (action == action04) { + } else if (*action == *action04) { action04Ctr++; } else { assert(false);