Skip to content

Commit

Permalink
add differents elements of the mcts cpp algorithm. Currently, error w…
Browse files Browse the repository at this point in the history
…ith the terminal state
  • Loading branch information
Alphadius committed Jan 11, 2024
1 parent 5380aeb commit e65aa0f
Show file tree
Hide file tree
Showing 16 changed files with 799 additions and 45 deletions.
2 changes: 1 addition & 1 deletion blocking/src/CurvedBlockingClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ std::vector<std::pair<TCellID ,double>>
CurvedBlockingClassifier::list_Possible_Cuts()
{
std::vector<std::pair<TCellID ,double>> list_actions;
auto no_capt_elements = classify();
auto no_capt_elements = this->classify();
auto no_points_capt = no_capt_elements.non_captured_points;
auto no_curves_capt = no_capt_elements.non_captured_curves;

Expand Down
5 changes: 4 additions & 1 deletion rlBlocking/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(GMDS_INC
inc/gmds/rlBlocking/MCTSMove.h
inc/gmds/rlBlocking/MCTSMovePolycube.h
inc/gmds/rlBlocking/MCTSStatePolycube.h
inc/gmds/rlBlocking/MCTSAgent.h
)
set(GMDS_SRC
src/BlockingQuality.cpp
Expand All @@ -23,7 +24,9 @@ set(GMDS_SRC
src/MCTSTree.cpp
src/MCTSAlgorithm.cpp
src/MCTSMovePolycube.cpp
src/MCTSStatePolycube.cpp)
src/MCTSStatePolycube.cpp
src/MCTSAgent.cpp
)
#==============================================================================
add_library(${GMDS_LIB} ${GMDS_INC} ${GMDS_SRC})
#==============================================================================
Expand Down
24 changes: 24 additions & 0 deletions rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef GMDS_MCTSAGENT_H
#define GMDS_MCTSAGENT_H

#include <gmds/rlBlocking/MCTSTree.h>
/*----------------------------------------------------------------------------------------*/
namespace gmds {
/*----------------------------------------------------------------------------------------*/
class LIB_GMDS_RLBLOCKING_API MCTSAgent
{
// example of an agent based on the MCTS_tree. One can also use the tree directly.
MCTSTree *tree;
int max_iter, max_seconds, max_same_quality;

public:
MCTSAgent(MCTSState *starting_state, int max_iter = 100000, int max_seconds = 30, int max_same_quality=3);
~MCTSAgent();
const MCTSMove *genmove();
const MCTSState *get_current_state() const;
void feedback() const {tree->print_stats();}
};
}
/*----------------------------------------------------------------------------------------*/
#endif // GMDS_MCTSAGENT_H
/*----------------------------------------------------------------------------------------*/
19 changes: 12 additions & 7 deletions rlBlocking/inc/gmds/rlBlocking/MCTSAlgorithm.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//
// Created by bourmaudp on 02/12/22.
//
/*----------------------------------------------------------------------------------------*/
#ifndef GMDS_MCTSALGORITHM_H
#define GMDS_MCTSALGORITHM_H
/*----------------------------------------------------------------------------------------*/
#include "LIB_GMDS_RLBLOCKING_export.h"
#include <gmds/rlBlocking/MCTSTree.h>
#include <gmds/rlBlocking/MCTSAgent.h>
#include <gmds/rlBlocking/MCTSStatePolycube.h>
/*----------------------------------------------------------------------------------------*/
namespace gmds {
/*----------------------------------------------------------------------------------------*/
Expand All @@ -14,26 +14,31 @@ namespace gmds {
*/
class LIB_GMDS_RLBLOCKING_API MCTSAlgorithm
{
MCTSTree *tree;
int max_iter, max_seconds,max_same_quality;
public:

/*------------------------------------------------------------------------*/
/** @brief Constructor.
* @param
*/
MCTSAlgorithm();
MCTSAlgorithm(gmds::cad::GeomManager *AGeom,gmds::blocking::CurvedBlocking *ABlocking,int max_iter = 100000, int max_seconds = 30,int max_same_quality = 10);

/*------------------------------------------------------------------------*/
/** @brief Destructor. */
virtual ~MCTSAlgorithm();
~MCTSAlgorithm();

/*------------------------------------------------------------------------*/
/** @brief Performs the MCTS algorithm
*/
void execute();

private:
/** a mesh */
//Mesh* m_mesh;
/** a geom */
gmds::cad::GeomManager *m_geom;
/** a blocking */
gmds::blocking::CurvedBlocking *m_blocking;

};
/*----------------------------------------------------------------------------*/
}
Expand Down
2 changes: 2 additions & 0 deletions rlBlocking/inc/gmds/rlBlocking/MCTSMove.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define GMDS_MCTSMOVE_H
/*----------------------------------------------------------------------------------------*/
#include "LIB_GMDS_RLBLOCKING_export.h"
#include <string>
/*----------------------------------------------------------------------------------------*/
namespace gmds {
/*----------------------------------------------------------------------------------------*/
Expand All @@ -21,6 +22,7 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMove {
/** @brief Overloaded ==
*/
virtual bool operator==(const MCTSMove& AOther) const = 0;
virtual std::string sprint() const { return "Not implemented"; } // and optionally this
};
/*----------------------------------------------------------------------------*/
}
Expand Down
17 changes: 12 additions & 5 deletions rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
//
// Created by bourmaudp on 02/12/22.
//
/*----------------------------------------------------------------------------------------*/
#ifndef GMDS_MCTSMOVE_POLYCUBE_H
#define GMDS_MCTSMOVE_POLYCUBE_H
/*----------------------------------------------------------------------------------------*/
#include "LIB_GMDS_RLBLOCKING_export.h"
#include <gmds/rlBlocking/MCTSMove.h>
#include <gmds/utils/CommonTypes.h>
/*----------------------------------------------------------------------------------------*/
namespace gmds {
/*----------------------------------------------------------------------------------------*/
Expand All @@ -17,11 +15,20 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove {
/*------------------------------------------------------------------------*/
/** @brief Destructor
*/
virtual ~MCTSMovePolycube();
~MCTSMovePolycube();
/*------------------------------------------------------------------------*/
TCellID m_AIdEdge;
TCellID m_AIdBlock;
double m_AParamCut;
/** @brief if typeMove=0: delete block, typeMove=1 cut block
*/
bool m_typeMove;

/** @brief Overloaded ==
*/
virtual bool operator==(const MCTSMove& AOther) const;
MCTSMovePolycube(TCellID AIdEdge,TCellID AIdBlock, double AParamCut,bool ATypeMove);
bool operator==(const MCTSMove& AOther) const;

};
/*----------------------------------------------------------------------------*/
}
Expand Down
17 changes: 16 additions & 1 deletion rlBlocking/inc/gmds/rlBlocking/MCTSState.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/*----------------------------------------------------------------------------------------*/
#include "LIB_GMDS_RLBLOCKING_export.h"
#include <gmds/rlBlocking/MCTSMove.h>
#include <iostream>
/*----------------------------------------------------------------------------------------*/
#include <queue>
/*----------------------------------------------------------------------------------------*/
Expand Down Expand Up @@ -43,12 +44,26 @@ class LIB_GMDS_RLBLOCKING_API MCTSState {
/** @brief Rollout from this state (random simulation)
* @return the rollout status
*/
virtual ROLLOUT_STATUS rollout() const = 0;
virtual double state_rollout() const = 0;
/*------------------------------------------------------------------------*/
/** @brief check the result of a terminal state
* @return the value of the result: Win, Lose, Draw
*/
virtual ROLLOUT_STATUS result_terminal() const = 0;
/*------------------------------------------------------------------------*/
/** @brief Indicate if we have a terminal state (win=true, fail=false)
* @return true if we have a leaf (in the sense of a traditional tree)
*/
virtual bool is_terminal() const = 0;
/*------------------------------------------------------------------------*/
/** @brief Indicate if we have a terminal state (win=true, fail=false)
* @return true if we have a leaf (in the sense of a traditional tree)
*/
virtual double get_quality() const = 0;

virtual void print() const {
std::cout << "Printing not implemented" << std::endl;
}
};
/*----------------------------------------------------------------------------*/
}
Expand Down
53 changes: 45 additions & 8 deletions rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//
// Created by bourmaudp on 02/12/22.
//
/*----------------------------------------------------------------------------------------*/
#ifndef GMDS_MCTSSTATE_POLYCUBE_H
#define GMDS_MCTSSTATE_POLYCUBE_H
/*----------------------------------------------------------------------------------------*/
#include "LIB_GMDS_RLBLOCKING_export.h"
#include <gmds/rlBlocking/MCTSState.h>
#include <gmds/rlBlocking/MCTSMovePolycube.h>
#include <gmds/blocking/CurvedBlockingClassifier.h>
#include <gmds/cadfac/FACManager.h>
/*----------------------------------------------------------------------------------------*/
namespace gmds {
/*----------------------------------------------------------------------------------------*/
Expand All @@ -15,30 +15,67 @@ namespace gmds {
* MCST algorithm
*/
class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{

public:
/*------------------------------------------------------------------------*/
/** @brief Constructore
*/
MCTSStatePolycube(gmds::cad::GeomManager *Ageom, gmds::blocking::CurvedBlocking *ABlocking,
std::vector<double> AHist);
/*------------------------------------------------------------------------*/
/** @brief Destructor
*/
virtual ~MCTSStatePolycube();
~MCTSStatePolycube();
/*------------------------------------------------------------------------*/
/** @brief Gives the set of actions that can be tried from the current state
*/
virtual std::queue<MCTSMove *> *actions_to_try() const ;
std::queue<MCTSMove *> *actions_to_try() const ;
/*------------------------------------------------------------------------*/
/** @brief Performs the @p AMove to change of states
* @param[in] AMove the movement to apply to get to a new state
*/
virtual MCTSState *next_state(const MCTSMove *AMove) const;
MCTSState *next_state(const MCTSMove *AMove) const;
/*------------------------------------------------------------------------*/
/** @brief Rollout from this state (random simulation)
* @return the rollout status
*/
virtual ROLLOUT_STATUS rollout() const;
double state_rollout() const;

/** @brief check the history of qualities
* @return nb of same quality from the history
*/
int check_nb_same_quality() const;
/** @brief check the result of a terminal state
* @return Win = all elements are capt, Lose: parent_quality &lt enfant_quality,
* Draw : same quality for a long time
*/
ROLLOUT_STATUS result_terminal() const;
/*------------------------------------------------------------------------*/
/** @brief Indicate if we have a terminal state (win=true, fail=false)
* @return true if we have a leaf (in the sense of a traditional tree)
*/
virtual bool is_terminal() const;
bool is_terminal() const;
/** @brief return the blocking quality
* */
double get_quality() const;
/** @brief return the geom */
gmds::cad::GeomManager *get_geom();
/** @brief return the current blocking */
gmds::blocking::CurvedBlocking *get_blocking();
/** @brief return the current classifier */
gmds::blocking::CurvedBlockingClassifier *get_class();
/** @brief return the current classification */
gmds::blocking::ClassificationErrors get_errors();
/** @brief return the history of the parents quality */
std::vector<double> get_history() const;

private :
/** @brief the curved blocking of the current state */
gmds::blocking::CurvedBlocking* m_blocking;
gmds::cad::GeomManager* m_geom;
gmds::blocking::CurvedBlockingClassifier* m_class_blocking;
gmds::blocking::ClassificationErrors m_class_errors;
std::vector<double> m_history;
};
/*----------------------------------------------------------------------------*/
}
Expand Down
Loading

0 comments on commit e65aa0f

Please sign in to comment.