diff --git a/blocking/inc/gmds/blocking/CurvedBlocking.h b/blocking/inc/gmds/blocking/CurvedBlocking.h index c4d756a51..05ccc8e3e 100644 --- a/blocking/inc/gmds/blocking/CurvedBlocking.h +++ b/blocking/inc/gmds/blocking/CurvedBlocking.h @@ -54,7 +54,7 @@ struct CellInfo * @param AGeomDim on-classify geometric cell dimension (4 if not classified) * @param AGeomId on-classify geometric cell unique id */ - CellInfo(cad::GeomManager* AManager, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : + CellInfo(cad::GeomManager* AManager=NULL, const int ATopoDim = 4, const int AGeomDim = 4, const int AGeomId = NullID) : topo_dim(ATopoDim), topo_id(m_counter_global_id++), geom_manager(AManager),geom_dim(AGeomDim), geom_id(AGeomId) { } @@ -74,7 +74,7 @@ struct NodeInfo : CellInfo * @param AGeomId on-classify geometric cell unique id * @param APoint geometric location */ - NodeInfo(cad::GeomManager* AManager, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : + NodeInfo(cad::GeomManager* AManager=NULL, const int AGeomDim = 4, const int AGeomId = NullID, const math::Point &APoint = math::Point(0, 0, 0)) : CellInfo(AManager, 0, AGeomDim, AGeomId), point(APoint) { } diff --git a/blocking/tst/ExecutionActionsTestSuite.h b/blocking/tst/ExecutionActionsTestSuite.h index bf32499b4..27711d43d 100644 --- a/blocking/tst/ExecutionActionsTestSuite.h +++ b/blocking/tst/ExecutionActionsTestSuite.h @@ -706,3 +706,84 @@ TEST(ExecutionActionsTestSuite,cb5){ vtk_writer_edges.write("debug_blocking_edges.vtk"); } + + +TEST(ExecutionActionsTestSuite,cb2_auto) { + gmds::cad::FACManager geom_model; + set_up_file(&geom_model,"cb2.vtk"); + gmds::blocking::CurvedBlocking bl(&geom_model,true); + gmds::blocking::CurvedBlockingClassifier classifier(&bl); + + + classifier.clear_classification(); + + auto errors = classifier.classify(); + + + //Check nb points of the geometry and nb nodes of the blocking + ASSERT_EQ(16,geom_model.getNbPoints()); + ASSERT_EQ(24,geom_model.getNbCurves()); + ASSERT_EQ(10,geom_model.getNbSurfaces()); + ASSERT_EQ(8,bl.get_all_nodes().size()); + ASSERT_EQ(12,bl.get_all_edges().size()); + ASSERT_EQ(6,bl.get_all_faces().size()); + + + + //Check elements class and captured + //Check nb nodes/edges/faces no classified + ASSERT_EQ(0,errors.non_classified_nodes.size()); + ASSERT_EQ(0,errors.non_classified_edges.size()); + ASSERT_EQ(6,errors.non_classified_faces.size()); + + //Check nb points/curves/surfaces no captured + ASSERT_EQ(8,errors.non_captured_points.size()); + ASSERT_EQ(12,errors.non_captured_curves.size()); + ASSERT_EQ(10,errors.non_captured_surfaces.size()); + + auto listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + classifier.classify(); + + listEdgesCut = classifier.list_Possible_Cuts(); + //Do 1 cut + bl.cut_sheet(listEdgesCut.front().first,listEdgesCut.front().second); + + + gmds::Mesh m(gmds::MeshModel(gmds::DIM3|gmds::N|gmds::E|gmds::F|gmds::R|gmds::E2N|gmds::F2N|gmds::R2N)); + bl.convert_to_mesh(m); + + + gmds::IGMeshIOService ios(&m); + gmds::VTKWriter vtk_writer(&ios); + vtk_writer.setCellOptions(gmds::N|gmds::R); + vtk_writer.setDataOptions(gmds::N|gmds::R); + vtk_writer.write("cb2_debug_blocking.vtk"); + gmds::VTKWriter vtk_writer_edges(&ios); + vtk_writer_edges.setCellOptions(gmds::N|gmds::E); + vtk_writer_edges.setDataOptions(gmds::N|gmds::E); + vtk_writer_edges.write("cb2_debug_blocking_edges.vtk"); + gmds::VTKWriter vtk_writer_faces(&ios); + vtk_writer_faces.setCellOptions(gmds::N|gmds::F); + vtk_writer_faces.setDataOptions(gmds::N|gmds::F); + vtk_writer_faces.write("cb2_debug_blocking_faces.vtk"); + + +} + + diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h index ebd3b9059..f85f81346 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSAgent.h @@ -2,6 +2,8 @@ #define GMDS_MCTSAGENT_H #include +#include +#include /*----------------------------------------------------------------------------------------*/ namespace gmds { /*----------------------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h index 75f061783..8b276a22a 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMove.h @@ -22,7 +22,8 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMove { /** @brief Overloaded == */ virtual bool operator==(const MCTSMove& AOther) const = 0; - virtual std::string sprint() const { return "Not implemented"; } // and optionally this + virtual std::string sprint() const { return "Not implemented"; } + virtual void print() const =0; // and optionally this }; /*----------------------------------------------------------------------------*/ } diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h index aaf9bdc43..0af5a6449 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSMovePolycube.h @@ -20,14 +20,15 @@ struct LIB_GMDS_RLBLOCKING_API MCTSMovePolycube: public MCTSMove { TCellID m_AIdEdge; TCellID m_AIdBlock; double m_AParamCut; - /** @brief if typeMove=0: delete block, typeMove=1 cut block + /** @brief if typeMove=2: delete block, typeMove=1 cut block */ - bool m_typeMove; + unsigned int m_typeMove; /** @brief Overloaded == */ - MCTSMovePolycube(TCellID AIdEdge,TCellID AIdBlock, double AParamCut,bool ATypeMove); + MCTSMovePolycube(TCellID AIdEdge = -1,TCellID AIdBlock = -1 , double AParamCut = 0,unsigned int ATypeMove = -1); bool operator==(const MCTSMove& AOther) const; + void print() const; }; /*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h index 9952850fc..92301494e 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSState.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSState.h @@ -34,7 +34,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSState { /*------------------------------------------------------------------------*/ /** @brief Gives the set of actions that can be tried from the current state */ - virtual std::queue *actions_to_try() const = 0; + virtual std::deque *actions_to_try() const = 0; /*------------------------------------------------------------------------*/ /** @brief Performs the @p AMove to change of states * @param[in] AMove the movement to apply to get to a new state diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h index 354396ea3..bd4341015 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSStatePolycube.h @@ -29,7 +29,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ /*------------------------------------------------------------------------*/ /** @brief Gives the set of actions that can be tried from the current state */ - std::queue *actions_to_try() const ; + std::deque *actions_to_try() const ; /*------------------------------------------------------------------------*/ /** @brief Performs the @p AMove to change of states * @param[in] AMove the movement to apply to get to a new state @@ -69,6 +69,9 @@ class LIB_GMDS_RLBLOCKING_API MCTSStatePolycube: public MCTSState{ /** @brief return the history of the parents quality */ std::vector get_history() const; + /** @brief update the classification of a state */ + void update_class(); + private : /** @brief the curved blocking of the current state */ gmds::blocking::CurvedBlocking* m_blocking; diff --git a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h index e55407b21..e8e2328c0 100644 --- a/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h +++ b/rlBlocking/inc/gmds/rlBlocking/MCTSTree.h @@ -34,7 +34,7 @@ class LIB_GMDS_RLBLOCKING_API MCTSNode { /** @brief the parent for the current node*/ MCTSNode *parent; /** @brief queue of untried actions*/ - std::queue *untried_actions; + std::deque *untried_actions; /** @brief update the nb simulations and the score after a rollout*/ void backpropagate(double w, int n); public: @@ -72,6 +72,8 @@ class LIB_GMDS_RLBLOCKING_API MCTSNode { MCTSNode *advance_tree(const MCTSMove *m); /** @brief Return the state of the node. */ const MCTSState *get_current_state() const; + /** @brief Return the children of the node. */ + std::vector *get_children(); /** @brief Print the tree and the stats. */ void print_stats() const; /** @brief Calculate the q rate of a node. It's: wins-looses */ diff --git a/rlBlocking/src/MCTSAgent.cpp b/rlBlocking/src/MCTSAgent.cpp index 1975c7b49..01bfd526e 100644 --- a/rlBlocking/src/MCTSAgent.cpp +++ b/rlBlocking/src/MCTSAgent.cpp @@ -14,7 +14,8 @@ MCTSAgent::~MCTSAgent(){ delete tree; } /*----------------------------------------------------------------------------*/ -const MCTSMove *MCTSAgent::genmove() { +const MCTSMove *MCTSAgent::genmove() +{ // If game ended from opponent move, we can't do anything if (tree->get_current_state()->is_terminal()) { return NULL; diff --git a/rlBlocking/src/MCTSAlgorithm.cpp b/rlBlocking/src/MCTSAlgorithm.cpp index 166e142b2..4338cbb17 100644 --- a/rlBlocking/src/MCTSAlgorithm.cpp +++ b/rlBlocking/src/MCTSAlgorithm.cpp @@ -25,7 +25,7 @@ void MCTSAlgorithm::execute() MCTSState *state = new MCTSStatePolycube(this->m_geom, this->m_blocking, std::vector ()); //state->print(); // IMPORTANT: state will be garbage after advance_tree() - MCTSAgent agent(state, 1000); + MCTSAgent agent(state, 100); do { agent.feedback(); agent.genmove(); @@ -38,6 +38,8 @@ void MCTSAlgorithm::execute() // } done = new_state->is_terminal(); } while (!done); + + std::cout<<"==========================================================="< +#include /*----------------------------------------------------------------------------*/ using namespace gmds; + /*----------------------------------------------------------------------------*/ MCTSStatePolycube::MCTSStatePolycube(gmds::cad::GeomManager* AGeom, gmds::blocking::CurvedBlocking* ABlocking, std::vector hist ) - :m_geom(AGeom),m_blocking(ABlocking),m_history(hist) + :m_geom(AGeom),m_history(hist) { + m_blocking = new blocking::CurvedBlocking(*ABlocking); gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); - m_class_errors = m_class_blocking->classify(); + m_class_errors = m_class_blocking->classify(0.2); ;} /*----------------------------------------------------------------------------*/ MCTSStatePolycube::~MCTSStatePolycube() noexcept -{ delete m_class_blocking;} +{ + delete m_class_blocking; + delete m_blocking; +} /*----------------------------------------------------------------------------*/ -std::queue * +std::deque * MCTSStatePolycube::actions_to_try() const { - std::queue *Q = new std::queue(); + std::deque *Q = new std::deque(); if (m_class_errors.non_captured_points.size()== 0){ + std::cout<<"POINTS CAPT :"<get_all_id_blocks(); for(auto b : blocks){ - Q->push(new MCTSMovePolycube(NullID,b,0,0)); + Q->push_back(new MCTSMovePolycube(NullID,b,0,2)); } } else{ + std::cout<<"NB CURVES CAPT :"<< m_class_errors.non_captured_curves.size()<list_Possible_Cuts(); for(auto c : listPossibleCuts){ - Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); } } } else{ + std::cout<<"POINTS NO CAPT :"<list_Possible_Cuts(); for(auto c : listPossibleCuts){ - Q->push(new MCTSMovePolycube(c.first,NullID,c.second,1)); + Q->push_back(new MCTSMovePolycube(c.first,NullID,c.second,1)); } } + std::cout<<"LIST ACTIONS :"<print(); + } return Q; } /*----------------------------------------------------------------------------*/ MCTSState *MCTSStatePolycube::next_state(const gmds::MCTSMove *AMove) const { + std::cout<<"==================== EXECUTE ACTION ! ===================="< hist_update = get_history(); hist_update.push_back(get_quality()); - MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,this->m_blocking,hist_update); - if(m->m_typeMove == 0){ - new_state->m_blocking->remove_block(m->m_AIdBlock); + gmds::blocking::CurvedBlocking* new_b = new gmds::blocking::CurvedBlocking(*m_blocking); + MCTSStatePolycube *new_state = new MCTSStatePolycube(this->m_geom,new_b,hist_update); + if(m->m_typeMove == 2){ + //TODO ERROR, sometimes, block select not in the current blocks list...Check why !!! + std::cout<<"LIST BLOCK BLOCKING : "<get_all_id_blocks()){ + std::cout<m_AIdBlock){ + b_in_list = true; + break; + } + } + if(b_in_list){ + std::cout<<"BLOCK A DELETE :"<m_AIdBlock<m_blocking->remove_block(m->m_AIdBlock); + } + else{ + std::cout<<"BLOCK A DELETE :"<get_all_id_blocks().front()<m_blocking->remove_block(m_blocking->get_all_id_blocks().front()); + } + + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); return new_state; } else if(m->m_typeMove ==1) { new_state->m_blocking->cut_sheet(m->m_AIdEdge,m->m_AParamCut); + new_state->update_class(); + //SAVE Blocking vtk + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/cb2/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "cb2_action"+ id_act +".vtk"; + new_state->m_blocking->save_vtk_blocking(name_save_folder+name_file); return new_state; } else{ - std::cerr << "Warning: Bad type move !" << std::endl; + std::cerr << "Warning: Bad type move ! \n Type move :" << m->m_typeMove << " & ID " << m->m_AIdEdge<< std::endl; return new_state; } + std::string name_save_folder = "/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/"; + std::string id_act = std::to_string(m->m_AIdEdge); + std::string name_file = "M1_action"+ id_act +".vtk"; + m_blocking->save_vtk_blocking(name_save_folder+name_file); } /*----------------------------------------------------------------------------*/ @@ -69,6 +122,7 @@ double MCTSStatePolycube::state_rollout() const { std::cout<<"STATE ROLLOUT"< *list_action = actions_to_try(); + long long r; int a; MCTSStatePolycube *curstate = (MCTSStatePolycube *) this; // TODO: ignore const... srand(time(NULL)); bool first = true; do { - if (list_action->empty()) { - std::cerr << "Warning: Ran out of available moves and state is not terminal?"; - return 0.0; - } + std::deque *list_action = actions_to_try(); //Get first move/action //But, maybe, better to take rand move if its a delete move... MCTSMove *firstMove = list_action->front(); //TODO: implement random move when only delete moves is possible - list_action->pop(); + list_action->pop_front(); MCTSStatePolycube *old = curstate; + std::cout<<"===== SIZE UNTRIED ACTIONS : "<size()+1<<" ====="<next_state(firstMove); if (!first) { delete old; @@ -109,10 +161,10 @@ MCTSStatePolycube::state_rollout() const first = false; } while (!curstate->is_terminal()); - if(MCTSStatePolycube::result_terminal() == WIN){ + if(curstate->result_terminal() == WIN){ res=1; } - else if (MCTSStatePolycube::result_terminal() == LOSE) { + else if (curstate->result_terminal() == LOSE) { res=-1; } else{ @@ -126,17 +178,19 @@ MCTSStatePolycube::state_rollout() const MCTSStatePolycube::ROLLOUT_STATUS MCTSStatePolycube::result_terminal() const { - int max_nb_same = 3; - if (get_quality() == 0) { + if (m_class_errors.non_captured_points.empty() && m_class_errors.non_captured_curves.empty() && m_class_errors.non_captured_surfaces.empty()) { return WIN; } - else if (check_nb_same_quality() >= max_nb_same){ + else if (check_nb_same_quality() >= 3){ return DRAW; } - else if (m_history.back() < get_quality()){ + else if (!m_history.empty() && m_history.back() < this->get_quality()){ + return LOSE; + } + else if (this->actions_to_try()->empty()){ return LOSE; } - std::cerr << "ERROR: NOT terminal state !" << std::endl; + std::cerr << "ERROR: NOT terminal state ..." << std::endl; return DRAW; } /*----------------------------------------------------------------------------*/ @@ -164,7 +218,10 @@ MCTSStatePolycube::is_terminal() const else if(check_nb_same_quality() >= 3){ return true; } - else if(!m_history.empty() && m_history.back() < get_quality()){ + else if(!m_history.empty() && m_history.back() < this->get_quality()){ + return true; + } + else if(this->actions_to_try()->empty()){ return true; } else { @@ -178,7 +235,7 @@ double { return m_class_errors.non_captured_points.size() * 0.8 + m_class_errors.non_captured_curves.size() * 0.6 + m_class_errors.non_captured_surfaces.size() * 0.4; -} + } /*----------------------------------------------------------------------------*/ gmds::cad::GeomManager* MCTSStatePolycube::get_geom(){ return m_blocking->geom_model(); @@ -206,3 +263,10 @@ std::vector MCTSStatePolycube::get_history() const return m_history; } /*----------------------------------------------------------------------------*/ +void MCTSStatePolycube::update_class() +{ + gmds::blocking::CurvedBlockingClassifier classifier(m_blocking); + m_class_blocking = new blocking::CurvedBlockingClassifier(classifier); + m_class_errors = m_class_blocking->classify(0.2); +} +/*----------------------------------------------------------------------------*/ diff --git a/rlBlocking/src/MCTSTree.cpp b/rlBlocking/src/MCTSTree.cpp index dfb6f6d80..24044c6a4 100644 --- a/rlBlocking/src/MCTSTree.cpp +++ b/rlBlocking/src/MCTSTree.cpp @@ -28,7 +28,7 @@ MCTSNode::~MCTSNode() { delete children; while (!untried_actions->empty()) { delete untried_actions->front(); // if a move is here then it is not a part of a child node and needs to be deleted here - untried_actions->pop(); + untried_actions->pop_front(); } delete untried_actions; } @@ -44,7 +44,7 @@ void MCTSNode::expand() { } // get next untried action MCTSMove *next_move = untried_actions->front(); // get value - untried_actions->pop(); // remove it + untried_actions->pop_front(); // remove it MCTSState *next_state = state->next_state(next_move); if(state->get_quality() == next_state->get_quality()){ @@ -66,6 +66,12 @@ const MCTSState *MCTSNode::get_current_state() const return state; } /*----------------------------------------------------------------------------*/ +std::vector +*MCTSNode::get_children() +{ + return children; +} +/*----------------------------------------------------------------------------*/ bool MCTSNode::is_terminal() const { @@ -89,7 +95,9 @@ unsigned int MCTSNode::get_size() const { /*----------------------------------------------------------------------------*/ MCTSNode *MCTSNode::select_best_child(double c) const { /** selects best child based on the winrate of whose turn it is to play */ - if (children->empty()) return NULL; + if (children->empty()) { + return NULL; + } else if (children->size() == 1) return children->at(0); else { double uct, max = -1; @@ -166,11 +174,25 @@ void MCTSNode::print_stats() const { << "Tree size: " << size << std::endl << "Number of simulations: " << number_of_simulations << std::endl << "Branching factor at root: " << children->size() << std::endl; - // print TOPK of them along with their winrates -// std::cout << "Best moves:" << std::endl; -// for (int i = 0 ; i < children->size() && i < TOPK ; i++) { -// std::cout << " " << i + 1 << ". " << children->at(i)->move->sprint() << " --> " -// << std::setprecision(4) << 100.0 * children->at(i)->calculate_winrate(state->player1_turn()) << "%" << endl; + // Print the best move for a current node +// MCTSNode *bestChild; +// bool first = true; +// double winRateChild = 0; +// if(!children->empty()) { +// for (int i = 0; i < children->size(); i++) { +// if (first) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// first = false; +// } +// +// else if (winRateChild < children->at(i)->calculate_winrate()) { +// bestChild = children->at(i); +// winRateChild = bestChild->calculate_winrate(); +// } +// } +// std::cout << "Best Move :" << std::endl; +// bestChild->move->print(); // } std::cout << "________________________________" << std::endl; } diff --git a/rlBlocking/tst/MCTSTestSuite.h b/rlBlocking/tst/MCTSTestSuite.h index de5bce2b8..3d40cb22d 100644 --- a/rlBlocking/tst/MCTSTestSuite.h +++ b/rlBlocking/tst/MCTSTestSuite.h @@ -37,23 +37,20 @@ TEST(MCTSTestSuite, testExAglo) { gmds::cad::FACManager geom_model; - set_up_MCTS(&geom_model,"M1.vtk"); + set_up_MCTS(&geom_model,"cb2.vtk"); gmds::blocking::CurvedBlocking bl(&geom_model,true); - bl.save_vtk_blocking("/home/bourmaudp/Documents/PROJETS/gmds/gmds_Correction_Class_Dev/saveResults/M1_init_blocking.vtk"); - std::cout<<"NB points : "<< geom_model.getPoints().size()<execute(); + + std::cout<<"==================== END TEST ! ===================="<