Skip to content

Commit

Permalink
Merge pull request #52 from LIHPC-Computational-Geometry/45-RL_Actor_…
Browse files Browse the repository at this point in the history
…Critic

45 RL environment and model
  • Loading branch information
ArzhelaR authored Aug 27, 2024
2 parents f66cba8 + d03f6d8 commit f904e1f
Show file tree
Hide file tree
Showing 21 changed files with 1,047 additions and 77 deletions.
9 changes: 5 additions & 4 deletions actions/triangular_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from model.mesh_struct.mesh import Mesh
from model.mesh_struct.mesh_elements import Dart, Node
from model.mesh_analysis import degree
from model.mesh_analysis import degree, isFlipOk


def flip_edge_ids(mesh: Mesh, id1: int, id2: int) -> True:
Expand All @@ -11,7 +11,8 @@ def flip_edge_ids(mesh: Mesh, id1: int, id2: int) -> True:

def flip_edge(mesh: Mesh, n1: Node, n2: Node) -> True:
found, d = mesh.find_inner_edge(n1, n2)
if not found:

if not found or not isFlipOk(d):
return False

d2, d1, d11, d21, d211, n1, n2, n3, n4 = active_triangles(mesh, d)
Expand Down Expand Up @@ -113,5 +114,5 @@ def test_degree(n: Node) -> bool:
"""
if degree(n) > 10:
return False


else:
return True
Empty file added environment/__init__.py
Empty file.
121 changes: 121 additions & 0 deletions environment/trimesh_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Any
import numpy as np
from model.mesh_analysis import global_score, isValidAction, find_template_opposite_node
from model.mesh_struct.mesh_elements import Dart
from model.mesh_struct.mesh import Mesh
from actions.triangular_actions import flip_edge
from model.random_trimesh import random_flip_mesh

# possible actions
FLIP = 0
GLOBAL = 0


class TriMesh:
def __init__(self, mesh=None, mesh_size: int = None, max_steps: int = 50, feat: int = 0):
self.mesh = mesh if mesh is not None else random_flip_mesh(mesh_size)
self.mesh_size = len(self.mesh.nodes)
self.size = len(self.mesh.dart_info)
self.actions = np.array([FLIP])
self.reward = 0
self.steps = 0
self.max_steps = max_steps
self.nodes_scores = global_score(self.mesh)[0]
self.ideal_score = global_score(self.mesh)[2]
self.terminal = False
self.feat = feat
self.won = 0

def reset(self, mesh=None):
self.reward = 0
self.steps = 0
self.terminal = False
self.mesh = mesh if mesh is not None else random_flip_mesh(self.mesh_size)
self.size = len(self.mesh.dart_info)
self.nodes_scores = global_score(self.mesh)[0]
self.ideal_score = global_score(self.mesh)[2]
self.won = 0

def step(self, action):
dart_id = action[1]
_, mesh_score, mesh_ideal_score = global_score(self.mesh)
d = Dart(self.mesh, dart_id)
d1 = d.get_beta(1)
n1 = d.get_node()
n2 = d1.get_node()
flip_edge(self.mesh, n1, n2)
self.steps += 1
next_nodes_score, next_mesh_score, _ = global_score(self.mesh)
self.nodes_scores = next_nodes_score
self.reward = (mesh_score - next_mesh_score)*10
if self.steps >= self.max_steps or next_mesh_score == mesh_ideal_score:
if next_mesh_score == mesh_ideal_score:
self.won = True
self.terminal = True

def get_x(self, s: Mesh, a: int) -> tuple[Any, list[int | list[int]]]:
"""
Get the feature vector of the state-action pair
:param s: the state
:param a: the action
:return: the feature vector and valid darts id
"""
if s is None:
s = self.mesh
if self.feat == GLOBAL:
return get_x_global_4(self, s)


def get_x_global_4(env, state: Mesh) -> tuple[Any, list[int | list[int]]]:
"""
Get the feature vector of the state.
:param state: the state
:param env: The environment
:return: the feature vector
"""
mesh = state
nodes_scores = global_score(mesh)[0]
size = len(mesh.dart_info)
template = np.zeros((size, 6))

for d_info in mesh.dart_info:

d = Dart(mesh, d_info[0])
A = d.get_node()
d1 = d.get_beta(1)
B = d1.get_node()
d11 = d1.get_beta(1)
C = d11.get_node()

#Template niveau 1
template[d_info[0], 0] = nodes_scores[C.id]
template[d_info[0], 1] = nodes_scores[A.id]
template[d_info[0], 2] = nodes_scores[B.id]

#template niveau 2

n_id = find_template_opposite_node(d)
if n_id is not None:
template[d_info[0], 3] = nodes_scores[n_id]
n_id = find_template_opposite_node(d1)
if n_id is not None:
template[d_info[0], 4] = nodes_scores[n_id]
n_id = find_template_opposite_node(d11)
if n_id is not None:
template[d_info[0], 5] = nodes_scores[n_id]

dart_to_delete = []
dart_ids = []
for i in range(size):
d = Dart(mesh, i)
if not isValidAction(mesh, d.id):
dart_to_delete.append(i)
else :
dart_ids.append(i)
valid_template = np.delete(template, dart_to_delete, axis=0)
score_sum = np.sum(np.abs(valid_template), axis=1)
indices_top_10 = np.argsort(score_sum)[-5:][::-1]
valid_dart_ids = [dart_ids[i] for i in indices_top_10]
X = valid_template[indices_top_10, :]
X = X.flatten()
return X, valid_dart_ids
53 changes: 6 additions & 47 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,13 @@
from view.window import Game
from model.mesh_struct.mesh import Mesh
from mesh_display import MeshDisplay
import model.random_trimesh as TM
#import model.reader as Reader

import sys
import json

from user_game import user_game
from train import train


# Press the green button in the gutter to run the script.
if __name__ == '__main__':

if len(sys.argv) != 2:
print("Usage: main.py <nb_nodes_of_the_mesh>")
else:
cmap = TM.random_mesh(int(sys.argv[1]))
mesh_disp = MeshDisplay(cmap)
g = Game(cmap, mesh_disp)
g.run()

"""
#Code to load a json file and create a mesh
if len(sys.argv) != 2:
print("Usage: main.py <mesh_file.json>")
else:
f = open(sys.argv[1])
json_mesh = json.load(f)
cmap = Mesh(json_mesh['nodes'], json_mesh['faces'])
mesh_disp = MeshDisplay(cmap)
g = Game(cmap, mesh_disp)
g.run()
"""
"""
# Code to load a Medit .mesh file and create a mesh
if len(sys.argv) != 2:
print("Usage: main.py <mesh_file.mesh>")
else:
cmap = Reader.read_medit(sys.argv[1])
mesh_disp = MeshDisplay(cmap)
g = Game(cmap, mesh_disp)
g.run()
"""
"""
# Code to load a gmsh .msh file and create a mesh
if len(sys.argv) != 2:
print("Usage: main.py <mesh_file.msh>")
if len(sys.argv) == 2:
user_game(int(sys.argv[1]))
else:
cmap = Reader.read_gmsh(sys.argv[1])
mesh_disp = MeshDisplay(cmap)
g = Game(cmap, mesh_disp)
g.run()
"""
train()
84 changes: 80 additions & 4 deletions model/mesh_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import sqrt, degrees, radians, cos, sin
from math import sqrt, degrees, radians, cos, sin, acos
import numpy as np

from model.mesh_struct.mesh_elements import Dart, Node
Expand Down Expand Up @@ -113,9 +113,9 @@ def adjacent_darts(n: Node) -> list[Dart]:
d = Dart(n.mesh, d_info[0])
d_nfrom = d.get_node()
d_nto = d.get_beta(1)
if d_nfrom == n:
if d_nfrom == n and d not in adj_darts:
adj_darts.append(d)
if d_nto.get_node() == n:
if d_nto.get_node() == n and d not in adj_darts:
adj_darts.append(d)
return adj_darts

Expand All @@ -137,7 +137,8 @@ def degree(n: Node) -> int:
boundary_darts.append(d)
else:
adjacency += 0.5

if adjacency != int(adjacency):
raise ValueError("Adjacency error")
return adjacency


Expand Down Expand Up @@ -191,3 +192,78 @@ def find_opposite_node(d: Dart) -> (int, int):
y_C = A.y() + y_AC

return x_C, y_C

def find_template_opposite_node(d: Dart) -> (int):
"""
Find the the vertex opposite in the adjacent triangle
:param d: a dart
:return: the node found
"""

d2 = d.get_beta(2)
if d2 is not None:
d21 = d2.get_beta(1)
d211 = d21.get_beta(1)
node_opposite = d211.get_node()
return node_opposite.id
else:
return None


def node_in_mesh(mesh: Mesh, x: float, y: float) -> (bool, int):
"""
Search if the node of coordinate (x, y) is inside the mesh.
:param mesh: the mesh to work with
:param x: X coordinate
:param y: Y coordinate
:return: a boolean indicating if the node is inside the mesh and the id of the node if it is.
"""
n_id = 0
for n in mesh.nodes:
if abs(x - n[0]) <= 0.1 and abs(y - n[1]) <= 0.1:
return True, n_id
n_id = n_id + 1
return False, None


def isValidAction(mesh, dart_id: int) -> bool:
d = Dart(mesh, dart_id)
boundary_darts = get_boundary_darts(mesh)
if d in boundary_darts or not isFlipOk(d):
return False
else:
return True

def get_angle_by_coord(x1: float, y1: float, x2: float, y2: float, x3:float, y3:float) -> float:
BAx, BAy = x1 - x2, y1 - y2
BCx, BCy = x3 - x2, y3 - y2

cos_ABC = (BAx * BCx + BAy * BCy) / (sqrt(BAx ** 2 + BAy ** 2) * sqrt(BCx ** 2 + BCy ** 2))

rad = acos(cos_ABC)
deg = degrees(rad)
return deg


def isFlipOk(d:Dart) -> bool:
d1 = d.get_beta(1)
d11 = d1.get_beta(1)
A = d.get_node()
B = d1.get_node()
C = d11.get_node()
d2 = d.get_beta(2)
if d2 is None:
return False
else:
d21 = d2.get_beta(1)
d211 = d21.get_beta(1)
D = d211.get_node()

# Calcul angle at d limits
angle_B = get_angle_by_coord(A.x(), A.y(), B.x(), B.y(), C.x(), C.y()) + get_angle_by_coord(A.x(), A.y(), B.x(), B.y(), D.x(), D.y())
angle_A = get_angle_by_coord(B.x(), B.y(), A.x(), A.y(), C.x(), C.y()) + get_angle_by_coord(B.x(), B.y(), A.x(), A.y(), D.x(), D.y())

if angle_B >= 180 or angle_A >= 180:
return False
else:
return True
Loading

0 comments on commit f904e1f

Please sign in to comment.