diff --git a/board/data/handler/internal/board.py b/board/data/handler/internal/board.py index 87813c3..f348744 100644 --- a/board/data/handler/internal/board.py +++ b/board/data/handler/internal/board.py @@ -1,4 +1,6 @@ +import random from board.data import Point, Section, Tile, Tiles +from cursor.data import Color def init_first_section() -> dict[int, dict[int, Section]]: @@ -18,65 +20,278 @@ class BoardHandler: # sections[y][x] sections: dict[int, dict[int, Section]] = init_first_section() + # 맵의 각 끝단 섹션 위치 + max_x: int = 0 + min_x: int = 0 + max_y: int = 0 + min_y: int = 0 + @staticmethod def fetch(start: Point, end: Point) -> Tiles: # 반환할 데이터 공간 미리 할당 out_width, out_height = (end.x - start.x + 1), (start.y - end.y + 1) out = bytearray(out_width * out_height) + # TODO: 새로운 섹션과의 관계로 경계값이 바뀔 수 있음. + # 이를 fetch 결과에 적용시킬 수 있도록 미리 다 만들어놓고 fetch를 시작해야 함. + # 현재는 섹션이 메모리 내부 레퍼런스로 저장되기 때문에 이렇게 미리 받아놓고 할 수 있음. + # 나중에는 다시 섹션을 가져와야 함. + sections = [] for sec_y in range(start.y // Section.LENGTH, end.y // Section.LENGTH - 1, - 1): for sec_x in range(start.x // Section.LENGTH, end.x // Section.LENGTH + 1): section = BoardHandler._get_or_create_section(sec_x, sec_y) + sections.append(section) - inner_start = Point( - x=max(start.x, section.abs_x) - (section.abs_x), - y=min(start.y, section.abs_y + Section.LENGTH-1) - section.abs_y - ) - inner_end = Point( - x=min(end.x, section.abs_x + Section.LENGTH-1) - section.abs_x, - y=max(end.y, section.abs_y) - section.abs_y - ) + for section in sections: + inner_start = Point( + x=max(start.x, section.abs_x) - (section.abs_x), + y=min(start.y, section.abs_y + Section.LENGTH-1) - section.abs_y + ) + inner_end = Point( + x=min(end.x, section.abs_x + Section.LENGTH-1) - section.abs_x, + y=max(end.y, section.abs_y) - section.abs_y + ) - fetched = section.fetch(start=inner_start, end=inner_end) + fetched = section.fetch(start=inner_start, end=inner_end) - x_gap, y_gap = (inner_end.x - inner_start.x + 1), (inner_start.y - inner_end.y + 1) + x_gap, y_gap = (inner_end.x - inner_start.x + 1), (inner_start.y - inner_end.y + 1) - # start로부터 떨어진 거리 - out_x = (section.abs_x + inner_start.x) - start.x - out_y = start.y - (section.abs_y + inner_start.y) + # start로부터 떨어진 거리 + out_x = (section.abs_x + inner_start.x) - start.x + out_y = start.y - (section.abs_y + inner_start.y) - for row_num in range(y_gap): - out_idx = (out_width * (out_y + row_num)) + out_x - data_idx = row_num * x_gap + for row_num in range(y_gap): + out_idx = (out_width * (out_y + row_num)) + out_x + data_idx = row_num * x_gap - data = fetched.data[data_idx:data_idx+x_gap] - out[out_idx:out_idx+x_gap] = data + data = fetched.data[data_idx:data_idx+x_gap] + out[out_idx:out_idx+x_gap] = data return Tiles(data=out) @staticmethod - def update_tile(p: Point, tile: Tile): - tiles = Tiles(data=bytearray([tile.data])) + def open_tile(p: Point) -> Tile: + section, inner_p = BoardHandler._get_section_from_abs_point(p) + + tiles = section.fetch(inner_p) + + tile = Tile.from_int(tiles.data[0]) + tile.is_open = True + + tiles.data[0] = tile.data + + section.update(data=tiles, start=inner_p) + BoardHandler._save_section(section) + + return tile + + @staticmethod + def open_tiles_cascade(p: Point) -> tuple[Point, Point, Tiles]: + """ + 지정된 타일부터 주변 타일들을 연쇄적으로 개방한다. + 빈칸들과 빈칸과 인접한숫자 타일까지 개방하며, 섹션 가장자리 데이터가 새로운 섹션으로 인해 중간에 수정되는 것을 방지하기 위해 + 섹션을 사용할 때 인접 섹션이 존재하지 않으면 미리 만들어 놓는다. + """ + # 탐색하며 발견한 섹션들 + sections: list[Section] = [] + + def fetch_section(sec_p: Point) -> Section: + # 가져오는 데이터의 일관성을 위해 주변 섹션을 미리 만들어놓기 + delta = [ + (0, 1), (0, -1), (-1, 0), (1, 0), # 상하좌우 + (-1, 1), (1, 1), (-1, -1), (1, -1), # 좌상 우상 좌하 우하 + ] + for dx, dy in delta: + new_p = Point(x=sec_p.x+dx, y=sec_p.y+dy) + _ = BoardHandler._get_or_create_section(new_p.x, new_p.y) + + new_section = BoardHandler._get_or_create_section(sec_p.x, sec_p.y) + return new_section + + def get_section(p: Point) -> tuple[Section, Point]: + sec_p = Point( + x=p.x // Section.LENGTH, + y=p.y // Section.LENGTH + ) + + section = None + for sec in sections: + # 이미 가지고 있으면 반환 + if sec.p == sec_p: + section = sec + break + + # 새로 가져오기 + if section is None: + section = fetch_section(sec_p) + sections.append(section) + + inner_p = Point( + x=p.x - section.abs_x, + y=p.y - section.abs_y + ) + + return section, inner_p + + queue = [] + queue.append(p) + + visited = set() + visited.add((p.x, p.y)) + + # 추후 fetch 범위 + min_x, min_y = p.x, p.y + max_x, max_y = p.x, p.y + + while len(queue) > 0: + p = queue.pop(0) + + # 범위 업데이트 + min_x, min_y = min(min_x, p.x), min(min_y, p.y) + max_x, max_y = max(max_x, p.x), max(max_y, p.y) + + sec, inner_p = get_section(p) + + # TODO: section.fetch_one(point) 같은거 만들어야 할 듯 + tile = Tile.from_int(sec.fetch(inner_p).data[0]) + + # 타일 열어주기 + tile.is_open = True + tile.is_flag = False + tile.color = None + + sec.update(Tiles(data=bytearray([tile.data])), inner_p) + + if tile.number is not None: + # 빈 타일 주변 number까지만 열어야 함. + continue + + # (x, y) 순서 + delta = [ + (0, 1), (0, -1), (-1, 0), (1, 0), # 상하좌우 + (-1, 1), (1, 1), (-1, -1), (1, -1), # 좌상 우상 좌하 우하 + ] + + # 큐에 추가될 포인트 리스트 + temp_list = [] + + for dx, dy in delta: + np = Point(x=p.x+dx, y=p.y+dy) + + if (np.x, np.y) in visited: + continue + visited.add((np.x, np.y)) + + sec, inner_p = get_section(np) + + nearby_tile = Tile.from_int(sec.fetch(inner_p).data[0]) + if nearby_tile.is_open: + # 이미 연 타일, 혹은 이전에 존재하던 열린 number 타일 + continue + + temp_list.append(np) + + queue.extend(temp_list) + + # 섹션 변경사항 모두 저장 + for section in sections: + BoardHandler._save_section(section) + + start_p = Point(min_x, max_y) + end_p = Point(max_x, min_y) + tiles = BoardHandler.fetch(start_p, end_p) + + return start_p, end_p, tiles + + @staticmethod + def set_flag_state(p: Point, state: bool, color: Color | None = None) -> Tile: + section, inner_p = BoardHandler._get_section_from_abs_point(p) + + tiles = section.fetch(inner_p) + + tile = Tile.from_int(tiles.data[0]) + tile.is_flag = state + tile.color = color - sec_p = Point(x=p.x // Section.LENGTH, y=p.y // Section.LENGTH) - section = BoardHandler.sections[sec_p.y][sec_p.x] + tiles.data[0] = tile.data + + section.update(data=tiles, start=inner_p) + BoardHandler._save_section(section) + + return tile + + def _get_section_from_abs_point(abs_p: Point) -> tuple[Section, Point]: + """ + 절대 좌표 abs_p를 포함하는 섹션, 그리고 abs_p의 섹션 내부 좌표를 반환한다. + """ + sec_p = Point( + x=abs_p.x // Section.LENGTH, + y=abs_p.y // Section.LENGTH + ) + + section = BoardHandler._get_or_create_section(sec_p.x, sec_p.y) inner_p = Point( - x=p.x - section.abs_x, - y=p.y - section.abs_y + x=abs_p.x - section.abs_x, + y=abs_p.y - section.abs_y ) - section.update(data=tiles, start=inner_p) + return section, inner_p + + def _save_section(section: Section): + BoardHandler.sections[section.p.y][section.p.x] = section + + @staticmethod + def get_random_open_position() -> Point: + """ + 전체 맵에서 랜덤한 열린 타일 위치를 하나 찾는다. + 섹션이 하나 이상 존재해야한다. + """ + # 이미 방문한 섹션들 + visited = set() - # 지금은 안 해도 되긴 할텐데 일단 해 놓기 - BoardHandler.sections[sec_p.y][sec_p.x] = section + sec_x_range = (BoardHandler.min_x, BoardHandler.max_x) + sec_y_range = (BoardHandler.min_y, BoardHandler.max_y) + + while True: + rand_p = Point( + x=random.randint(sec_x_range[0], sec_x_range[1]), + y=random.randint(sec_y_range[0], sec_y_range[1]) + ) + + if (rand_p.x, rand_p.y) in visited: + continue + + visited.add((rand_p.x, rand_p.y)) + + chosen_section = BoardHandler._get_section_or_none(rand_p.x, rand_p.y) + if chosen_section is None: + continue + + # 섹션 내부의 랜덤한 열린 타일 위치를 찾는다. + inner_point = randomly_find_open_tile(chosen_section) + if inner_point is None: + continue + + open_point = Point( + x=chosen_section.abs_x + inner_point.x, + y=chosen_section.abs_y + inner_point.y + ) + + return open_point @staticmethod def _get_or_create_section(x: int, y: int) -> Section: if y not in BoardHandler.sections: + BoardHandler.max_y = max(BoardHandler.max_y, y) + BoardHandler.min_y = min(BoardHandler.min_y, y) + BoardHandler.sections[y] = {} if x not in BoardHandler.sections[y]: + BoardHandler.max_x = max(BoardHandler.max_x, x) + BoardHandler.min_x = min(BoardHandler.min_x, x) + new_section = Section.create(Point(x, y)) # (x, y) @@ -110,3 +325,51 @@ def _get_or_create_section(x: int, y: int) -> Section: def _get_section_or_none(x: int, y: int) -> Section | None: if y in BoardHandler.sections and x in BoardHandler.sections[y]: return BoardHandler.sections[y][x] + + +def randomly_find_open_tile(section: Section) -> Point | None: + """ + 섹션 안에서 랜덤한 열린 타일 위치를 찾는다. + 시작 위치, 순회 방향의 순서를 무작위로 잡아 탐색한다. + 만약 열린 타일이 존재하지 않는다면 None. + """ + + # (증감값, 한계값) + directions = [ + (1, Section.LENGTH - 1), (-1, 0) # 순방향, 역방향 + ] + random.shuffle(directions) + + x_start = random.randint(0, Section.LENGTH - 1) + y_start = random.randint(0, Section.LENGTH - 1) + + pointers = [0, 0] # first, second + start_values = [0, 0] + + x_first = random.choice([True, False]) + x_pointer = 0 if x_first else 1 + y_pointer = 1 if x_first else 0 + + start_values[x_pointer] = x_start + start_values[y_pointer] = y_start + + # second 양방향 탐색 + for num, limit in directions: + for second in range(start_values[1], limit + num, num): + pointers[1] = second + + # first 양방향 탐색 + for num, limit in directions: + for first in range(start_values[0], limit + num, num): + pointers[0] = first + + x = pointers[x_pointer] + y = pointers[y_pointer] + + idx = y * Section.LENGTH + x + + tile = Tile.from_int(section.data[idx]) + if tile.is_open: + # 좌표계에 맞게 y 반전 + y = Section.LENGTH - y - 1 + return Point(x, y) diff --git a/board/data/handler/test/board_test.py b/board/data/handler/test/board_test.py index 429885d..fa7ff86 100644 --- a/board/data/handler/test/board_test.py +++ b/board/data/handler/test/board_test.py @@ -1,7 +1,9 @@ import unittest +from unittest.mock import patch, MagicMock from tests.utils import cases -from board.data import Point, Tile +from board.data import Point, Tile, Section from board.data.handler import BoardHandler +from cursor.data import Color from .fixtures import setup_board FETCH_CASE = \ @@ -44,15 +46,98 @@ def test_fetch(self, data, expect): self.assertEqual(data, expect) - def test_update_tile(self): - p = Point(-1, -1) + def test_open_tile(self): + p = Point(0, -2) - tile = Tile.from_int(0) - BoardHandler.update_tile(p=p, tile=tile) + result = BoardHandler.open_tile(p) tiles = BoardHandler.fetch(start=p, end=p) + tile = Tile.from_int(tiles.data[0]) - self.assertEqual(tiles.data[0], tile.data) + self.assertTrue(tile.is_open) + self.assertEqual(tile, result) + + @patch("board.data.Section.create") + def test_open_tiles_cascade(self, create_seciton_mock: MagicMock): + def stub_section_create(p: Point) -> Section: + return Section( + data=bytearray([0b10000000 for _ in range(Section.LENGTH ** 2)]), + p=p + ) + create_seciton_mock.side_effect = stub_section_create + + p = Point(0, 3) + + start_p, end_p, tiles = BoardHandler.open_tiles_cascade(p) + + self.assertEqual(len(create_seciton_mock.mock_calls), 20) + + self.assertEqual(start_p, Point(-1, 3)) + self.assertEqual(end_p, Point(3, -1)) + self.assertEqual(tiles, BoardHandler.fetch(start=start_p, end=end_p)) + + OPEN_0 = 0b10000000 + OPEN_1 = 0b10000001 + CLOSED_1 = 0b00000001 + BLUE_FLAG = 0b01110000 + PURPLE_FLAG = 0b00111001 + + expected = bytearray([ + OPEN_1, OPEN_0, OPEN_0, OPEN_0, OPEN_0, + OPEN_1, OPEN_1, OPEN_1, OPEN_1, OPEN_0, + OPEN_1, OPEN_1, BLUE_FLAG, OPEN_1, OPEN_0, + OPEN_0, OPEN_1, CLOSED_1, OPEN_1, OPEN_0, + OPEN_1, OPEN_1, PURPLE_FLAG, OPEN_1, OPEN_1 + ]) + self.assertEqual(tiles.data, expected) + + def test_set_flag_state_true(self): + p = Point(0, -2) + color = Color.BLUE + + result = BoardHandler.set_flag_state(p=p, state=True, color=color) + + tiles = BoardHandler.fetch(start=p, end=p) + tile = Tile.from_int(tiles.data[0]) + + self.assertTrue(tile.is_flag) + self.assertEqual(tile.color, color) + + self.assertEqual(tile, result) + + def test_set_flag_state_false(self): + p = Point(1, -1) + + result = BoardHandler.set_flag_state(p=p, state=False) + + tiles = BoardHandler.fetch(start=p, end=p) + tile = Tile.from_int(tiles.data[0]) + + self.assertFalse(tile.is_flag) + self.assertIsNone(tile.color) + + self.assertEqual(tile, result) + + def test_get_random_open_position(self): + for _ in range(10): + point = BoardHandler.get_random_open_position() + + tiles = BoardHandler.fetch(point, point) + tile = Tile.from_int(tiles.data[0]) + + self.assertTrue(tile.is_open) + + def test_get_random_open_position_one_section_one_open(self): + sec = BoardHandler.sections[-1][0] + BoardHandler.sections = {-1: {0: sec}} + + for _ in range(10): + point = BoardHandler.get_random_open_position() + + tiles = BoardHandler.fetch(point, point) + tile = Tile.from_int(tiles.data[0]) + + self.assertTrue(tile.is_open) if __name__ == "__main__": diff --git a/board/data/handler/test/fixtures.py b/board/data/handler/test/fixtures.py index e4f6ee9..88a11ec 100644 --- a/board/data/handler/test/fixtures.py +++ b/board/data/handler/test/fixtures.py @@ -45,3 +45,7 @@ def setup_board(): -1: Section(Point(-1, -1), tile_state_3) } } + BoardHandler.max_x = 0 + BoardHandler.min_x = -1 + BoardHandler.max_y = 0 + BoardHandler.min_y = -1 diff --git a/board/data/internal/section.py b/board/data/internal/section.py index 061a5a4..40d5e9e 100644 --- a/board/data/internal/section.py +++ b/board/data/internal/section.py @@ -77,14 +77,14 @@ def apply_neighbor_diagonal(self, neighbor): self_idx = (self_y * Section.LENGTH) + self_x neighbor_idx = (neighbor_y * Section.LENGTH) + neighbor_x - if self.data[self_idx] == MINE_TILE: + if self.data[self_idx] & MINE_TILE: affect_origin_mines_to_new( new_tiles=neighbor.data, x_range=(neighbor_x, neighbor_x), y_range=(neighbor_y, neighbor_y) ) - if neighbor.data[neighbor_idx] == MINE_TILE: + if neighbor.data[neighbor_idx] & MINE_TILE: affect_new_mines_to_origin( origin_tiles=self.data, new_tiles=neighbor.data, @@ -113,14 +113,14 @@ def apply_neighbor_vertical(self, neighbor): leftmost = max(0, x - 1) rightmost = min(x + 1, Section.LENGTH - 1) - if self.data[self_idx] == MINE_TILE: + if self.data[self_idx] & MINE_TILE: affect_origin_mines_to_new( new_tiles=neighbor.data, x_range=(leftmost, rightmost), y_range=(neighbor_y, neighbor_y) ) - if neighbor.data[neighbor_idx] == MINE_TILE: + if neighbor.data[neighbor_idx] & MINE_TILE: affect_new_mines_to_origin( origin_tiles=self.data, new_tiles=neighbor.data, @@ -149,14 +149,14 @@ def apply_neighbor_horizontal(self, neighbor): top = min(y + 1, Section.LENGTH - 1) bottom = max(0, y - 1) - if self.data[self_idx] == MINE_TILE: + if self.data[self_idx] & MINE_TILE: affect_origin_mines_to_new( new_tiles=neighbor.data, x_range=(neighbor_x, neighbor_x), y_range=(bottom, top) ) - if neighbor.data[neighbor_idx] == MINE_TILE: + if neighbor.data[neighbor_idx] & MINE_TILE: affect_new_mines_to_origin( origin_tiles=self.data, new_tiles=neighbor.data, @@ -192,7 +192,7 @@ def create(p: Point): cur_tile = data[rand_idx] # 이미 지뢰가 존재 - if cur_tile == MINE_TILE: + if cur_tile & MINE_TILE: continue # 주변 타일 검사 @@ -222,7 +222,7 @@ def affect_origin_mines_to_new(new_tiles: bytearray, x_range: tuple[int, int], y idx = (y * Section.LENGTH) + x tile = new_tiles[idx] - if tile == MINE_TILE: + if tile & MINE_TILE: continue num = tile & NUM_MASK @@ -250,7 +250,7 @@ def affect_new_mines_to_origin( idx = (y * Section.LENGTH) + x tile = origin_tiles[idx] - if tile == MINE_TILE: + if tile & MINE_TILE: continue num = tile & NUM_MASK @@ -266,7 +266,7 @@ def affect_new_mines_to_origin( new_tiles[new_tiles_idx] = cnt continue - origin_tiles[idx] = num + 1 + origin_tiles[idx] = tile + 1 def decrease_number_around_and_count_mines(tiles: bytearray, p: Point) -> int: @@ -278,7 +278,7 @@ def decrease_number_around_and_count_mines(tiles: bytearray, p: Point) -> int: def do(t: int, p: Point) -> tuple[int | None, bool]: nonlocal cnt - if t == MINE_TILE: + if t & MINE_TILE: cnt += 1 return None, False @@ -297,7 +297,7 @@ def remove_one_nearby_mine(tiles: bytearray, p: Point): 그 주변 타일의 num은 1씩 감소한다. """ def do(t: int, p: Point) -> tuple[int | None, bool]: - if t != MINE_TILE: + if not (t & MINE_TILE): return None, False cnt = decrease_number_around_and_count_mines(tiles=tiles, p=p) @@ -311,7 +311,7 @@ def increase_number_around(tiles: bytearray, p: Point): 주변 타일의 num을 1씩 증가시킨다. """ def do(t: int, p: Point) -> tuple[int | None, bool]: - if t == MINE_TILE: + if t & MINE_TILE: return None, False t += 1 diff --git a/board/data/test/section_test.py b/board/data/test/section_test.py index 55ca98e..00700e6 100644 --- a/board/data/test/section_test.py +++ b/board/data/test/section_test.py @@ -166,7 +166,7 @@ def setUp(self): # 왼쪽 위 섹션: 오른쪽 끝을 감싸는 지뢰들 self.left_top_section = Section(Point(-1, 1), data=bytearray([ MINES_OF(0), MINES_OF(1), MINES_OF(2), MINES_OF(2), - MINES_OF(0), MINES_OF(2), MINE_TILE__, MINE_TILE__, + MINES_OF(0), MINES_OF(2), MINE_TILE__, OPEN_MINE__, MINES_OF(0), MINES_OF(3), MINE_TILE__, MINES_OF(5), MINES_OF(0), MINES_OF(2), MINE_TILE__, MINE_TILE__ ])) @@ -174,19 +174,19 @@ def setUp(self): self.right_top_section = Section(Point(0, 1), data=bytearray([ MINES_OF(1), MINES_OF(1), MINES_OF(0), MINES_OF(0), MINE_TILE__, MINES_OF(2), MINES_OF(0), MINES_OF(0), - MINE_TILE__, MINES_OF(4), MINES_OF(1), MINES_OF(0), - MINE_TILE__, MINE_TILE__, MINES_OF(1), MINES_OF(0) + OPEN_MINE__, MINES_OF(4), MINES_OF(1), MINES_OF(0), + OPEN_MINE__, MINE_TILE__, MINES_OF(1), MINES_OF(0) ])) # 왼쪽 아래 섹션: 오른쪽 아래 끝단 2개 지뢰 self.left_bottom_section = Section(Point(-1, 0), data=bytearray([ MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(1), MINES_OF(1), MINES_OF(0), MINES_OF(0), MINES_OF(2), MINE_TILE__, - MINES_OF(0), MINES_OF(0), MINES_OF(2), MINE_TILE__ + MINES_OF(0), MINES_OF(0), MINES_OF(2), OPEN_MINE__ ])) # 오른쪽 아래 섹션: 왼쪽 위 끝단을 감싸는 지뢰들 self.right_bottom_section = Section(Point(0, 0), data=bytearray([ - MINES_OF(3), MINE_TILE__, MINES_OF(2), MINES_OF(0), + MINES_OF(3), OPEN_MINE__, MINES_OF(2), MINES_OF(0), MINE_TILE__, MINE_TILE__, MINES_OF(2), MINES_OF(0), MINES_OF(2), MINES_OF(2), MINES_OF(1), MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(0), MINES_OF(0) @@ -276,7 +276,7 @@ def test_apply_neighbor_num_overflow_left_right(self): self.right_top_section.data[8], # x=0, y=1 self.right_top_section.data[12], # x=0, y=0 ] - self.assertEqual(l.count(MINE_TILE__), 2) + self.assertEqual(l.count(MINE_TILE__) + l.count(OPEN_MINE__), 2) def test_apply_neighbor_num_overflow_right_left(self): self.right_top_section.apply_neighbor_horizontal( @@ -293,10 +293,11 @@ def test_apply_neighbor_num_overflow_right_left(self): self.left_top_section.data[15], # x=3, y=0 ] - self.assertEqual(l.count(MINE_TILE__), 4) + self.assertEqual(l.count(MINE_TILE__) + l.count(OPEN_MINE__), 4) MINE_TILE__ = 0b01000000 +OPEN_MINE__ = 0b11000000 def MINES_OF(n: int) -> int: diff --git a/board/event/handler/internal/board_handler.py b/board/event/handler/internal/board_handler.py index 052e4c2..3eb08c8 100644 --- a/board/event/handler/internal/board_handler.py +++ b/board/event/handler/internal/board_handler.py @@ -1,5 +1,6 @@ +import asyncio from event import EventBroker -from board.data import Point, Tile +from board.data import Point, Tile, Tiles from board.data.handler import BoardHandler from cursor.data import Color from message import Message @@ -9,6 +10,7 @@ TilesEvent, NewConnEvent, NewConnPayload, + NewCursorCandidatePayload, TryPointingPayload, PointingResultPayload, PointEvent, @@ -17,7 +19,9 @@ MovableResultPayload, ClickType, InteractionEvent, - TileStateChangedPayload + TilesOpenedPayload, + SingleTileOpenedPayload, + FlagSetPayload ) @@ -34,14 +38,35 @@ async def receive_fetch_tiles(message: Message[FetchTilesPayload]): async def receive_new_conn(message: Message[NewConnPayload]): sender = message.payload.conn_id - # 0, 0 기준으로 fetch width = message.payload.width height = message.payload.height - start_p = Point(x=-width, y=height) - end_p = Point(x=width, y=-height) + # 커서의 위치 + position = BoardHandler.get_random_open_position() - await BoardEventHandler._publish_tiles(start_p, end_p, [sender]) + start_p = Point( + x=position.x - width, + y=position.y + height + ) + end_p = Point( + x=position.x+width, + y=position.y-height + ) + publish_tiles = BoardEventHandler._publish_tiles(start_p, end_p, [sender]) + + message = Message( + event=NewConnEvent.NEW_CURSOR_CANDIDATE, + payload=NewCursorCandidatePayload( + conn_id=message.payload.conn_id, + width=width, height=height, + position=position + ) + ) + + await asyncio.gather( + publish_tiles, + EventBroker.publish(message) + ) @staticmethod async def _publish_tiles(start: Point, end: Point, to: list[str]): @@ -73,6 +98,7 @@ async def receive_try_pointing(message: Message[TryPointingPayload]): Point(pointer.x+1, pointer.y-1) ) + # 포인팅한 칸 포함 3x3칸 중 열린 칸이 존재하는지 확인 pointable = False for tile in tiles.data: t = Tile.from_int(tile) @@ -80,6 +106,8 @@ async def receive_try_pointing(message: Message[TryPointingPayload]): pointable = True break + publish_coroutines = [] + pub_message = Message( event=PointEvent.POINTING_RESULT, header={"receiver": sender}, @@ -89,54 +117,85 @@ async def receive_try_pointing(message: Message[TryPointingPayload]): ) ) - await EventBroker.publish(pub_message) - - cursor_pos = message.payload.cursor_position + publish_coroutines.append(EventBroker.publish(pub_message)) if not pointable: + await asyncio.gather(*publish_coroutines) return + cursor_pos = message.payload.cursor_position + # 인터랙션 범위 체크 if \ pointer.x < cursor_pos.x - 1 or \ pointer.x > cursor_pos.x + 1 or \ pointer.y < cursor_pos.y - 1 or \ pointer.y > cursor_pos.y + 1: + await asyncio.gather(*publish_coroutines) return # 보드 상태 업데이트하기 - tile = Tile.from_int(tiles.data[4]) # 3x3칸 중 가운데 + tile = Tile.from_int(tiles.data[4]) # 3x3칸 중 가운데 = 포인팅한 타일 click_type = message.payload.click_type if tile.is_open: + await asyncio.gather(*publish_coroutines) return match (click_type): # 닫힌 타일 열기 case ClickType.GENERAL_CLICK: if tile.is_flag: + await asyncio.gather(*publish_coroutines) return - tile.is_open = True + if tile.number is None: + # 빈 칸. 주변 칸 모두 열기. + start_p, end_p, tiles = BoardHandler.open_tiles_cascade(pointer) + tiles.hide_info() + tile_str = tiles.to_str() + + pub_message = Message( + event=InteractionEvent.TILES_OPENED, + payload=TilesOpenedPayload( + start_p=start_p, + end_p=end_p, + tiles=tile_str + ) + ) + publish_coroutines.append(EventBroker.publish(pub_message)) + else: + tile = BoardHandler.open_tile(pointer) + + tile_str = Tiles(data=bytearray([tile.data])).to_str() + + pub_message = Message( + event=InteractionEvent.SINGLE_TILE_OPENED, + payload=SingleTileOpenedPayload( + position=pointer, + tile=tile_str + ) + ) + publish_coroutines.append(EventBroker.publish(pub_message)) # 깃발 꽂기/뽑기 case ClickType.SPECIAL_CLICK: - color = message.payload.color - - tile.is_flag = not tile.is_flag - tile.color = color if tile.is_flag else None - - BoardHandler.update_tile(pointer, tile) - - pub_message = Message( - event=InteractionEvent.TILE_STATE_CHANGED, - payload=TileStateChangedPayload( - position=pointer, - tile=tile - ) - ) - - await EventBroker.publish(pub_message) + flag_state = not tile.is_flag + color = message.payload.color if flag_state else None + + _ = BoardHandler.set_flag_state(p=pointer, state=flag_state, color=color) + + pub_message = Message( + event=InteractionEvent.FLAG_SET, + payload=FlagSetPayload( + position=pointer, + is_set=flag_state, + color=color, + ) + ) + publish_coroutines.append(EventBroker.publish(pub_message)) + + await asyncio.gather(*publish_coroutines) @EventBroker.add_receiver(MoveEvent.CHECK_MOVABLE) @staticmethod diff --git a/board/event/handler/test/board_handler_test.py b/board/event/handler/test/board_handler_test.py index 4fa1048..42ac6aa 100644 --- a/board/event/handler/test/board_handler_test.py +++ b/board/event/handler/test/board_handler_test.py @@ -1,3 +1,4 @@ +import asyncio from cursor.data import Color from board.data import Point, Tile, Tiles from board.event.handler import BoardEventHandler @@ -10,6 +11,7 @@ TilesPayload, NewConnEvent, NewConnPayload, + NewCursorCandidatePayload, TryPointingPayload, PointingResultPayload, PointEvent, @@ -18,7 +20,9 @@ CheckMovablePayload, MovableResultPayload, InteractionEvent, - TileStateChangedPayload + SingleTileOpenedPayload, + TilesOpenedPayload, + FlagSetPayload ) import unittest @@ -124,41 +128,50 @@ async def test_fetch_tiles_receiver_normal_case(self, mock: AsyncMock): @patch("event.EventBroker.publish") async def test_receive_new_conn(self, mock: AsyncMock): conn_id = "ayo" + width = 1 + height = 1 message = Message( event=NewConnEvent.NEW_CONN, - payload=NewConnPayload(conn_id=conn_id, width=1, height=1) + payload=NewConnPayload(conn_id=conn_id, width=width, height=height) ) await BoardEventHandler.receive_new_conn(message) - mock.assert_called_once() - got: Message[TilesPayload] = mock.mock_calls[0].args[0] + # tiles, new-cursor-candidate + self.assertEqual(len(mock.mock_calls), 2) + # new-cursor-candidate + got: Message[NewCursorCandidatePayload] = mock.mock_calls[0].args[0] self.assertEqual(type(got), Message) - self.assertEqual(got.event, "multicast") + self.assertEqual(got.event, NewConnEvent.NEW_CURSOR_CANDIDATE) + + self.assertEqual(type(got.payload), NewCursorCandidatePayload) + self.assertEqual(got.payload.conn_id, conn_id) + self.assertEqual(got.payload.width, width) + self.assertEqual(got.payload.height, height) + position = got.payload.position + tiles = BoardHandler.fetch(position, position) + tile = Tile.from_int(tiles.data[0]) + self.assertTrue(tile.is_open) + + # tiles + got: Message[TilesPayload] = mock.mock_calls[1].args[0] + self.assertEqual(type(got), Message) + self.assertEqual(got.event, "multicast") self.assertIn("target_conns", got.header) self.assertEqual(len(got.header["target_conns"]), 1) - self.assertEqual(got.header["target_conns"][0], conn_id) + self.assertIn("origin_event", got.header) + self.assertEqual(got.header["origin_event"], TilesEvent.TILES) self.assertEqual(type(got.payload), TilesPayload) - self.assertEqual(got.payload.start_p.x, -1) - self.assertEqual(got.payload.start_p.y, 1) - self.assertEqual(got.payload.end_p.x, 1) - self.assertEqual(got.payload.end_p.y, -1) + self.assertEqual(got.payload.start_p, Point(position.x-width, position.y+height)) + self.assertEqual(got.payload.end_p, Point(position.x+width, position.y-height)) # 하는 김에 마스킹까지 같이 테스트 - empty_open = Tile.from_int(0b10000000) - one_open = Tile.from_int(0b10000001) - closed = Tile.from_int(0b00000000) - blue_flag = Tile.from_int(0b00110000) - purple_flag = Tile.from_int(0b00111000) + expected = BoardHandler.fetch(got.payload.start_p, got.payload.end_p) + expected.hide_info() - expected = Tiles(data=bytearray([ - one_open.data, one_open.data, blue_flag.data, - empty_open.data, one_open.data, closed.data, - one_open.data, one_open.data, purple_flag.data - ])) self.assertEqual(got.payload.tiles, expected.to_str()) @@ -219,7 +232,7 @@ async def test_try_pointing_pointable_closed_general_click(self, mock: AsyncMock await BoardEventHandler.receive_try_pointing(message) - # pointing-result, tile-state-changed 발행하는지 확인 + # pointing-result, single-tile-opened 발행하는지 확인 self.assertEqual(len(mock.mock_calls), 2) # pointing-result @@ -235,12 +248,12 @@ async def test_try_pointing_pointable_closed_general_click(self, mock: AsyncMock self.assertTrue(got.payload.pointable) self.assertEqual(got.payload.pointer, pointer) - # tile-state-changed - got: Message[PointingResultPayload] = mock.mock_calls[1].args[0] + # single-tile-opened + got: Message[SingleTileOpenedPayload] = mock.mock_calls[1].args[0] self.assertEqual(type(got), Message) - self.assertEqual(got.event, InteractionEvent.TILE_STATE_CHANGED) + self.assertEqual(got.event, InteractionEvent.SINGLE_TILE_OPENED) # payload 확인 - self.assertEqual(type(got.payload), TileStateChangedPayload) + self.assertEqual(type(got.payload), SingleTileOpenedPayload) self.assertEqual(got.payload.position, pointer) expected_tile = Tile.create( @@ -250,10 +263,40 @@ async def test_try_pointing_pointable_closed_general_click(self, mock: AsyncMock color=None, number=1 ) - fetched_tile = Tile.from_int(BoardHandler.fetch(start=pointer, end=pointer).data[0]) + tiles = BoardHandler.fetch(start=pointer, end=pointer) + fetched_tile = Tile.from_int(tiles.data[0]) self.assertEqual(fetched_tile, expected_tile) - self.assertEqual(got.payload.tile, expected_tile) + self.assertEqual(got.payload.tile, tiles.to_str()) + + @patch("event.EventBroker.publish") + async def test_try_pointing_pointable_closed_general_click_race(self, mock: AsyncMock): + cursor_pos = Point(0, 0) + pointer = Point(1, 0) + + # 코루틴 스위칭을 위해 sleep. 이게 되는 이유를 모르겠다. + async def sleep(_): + await asyncio.sleep(0) + mock.side_effect = sleep + + message = Message( + event=PointEvent.TRY_POINTING, + header={"sender": self.sender_id}, + payload=TryPointingPayload( + cursor_position=cursor_pos, + new_pointer=pointer, + click_type=ClickType.GENERAL_CLICK, + color=Color.BLUE + ) + ) + + await asyncio.gather( + BoardEventHandler.receive_try_pointing(message), + BoardEventHandler.receive_try_pointing(message) + ) + + # 첫번째: pointing-result, single-tile-opened 두번째: pointing-result 발행하는지 확인 + self.assertEqual(len(mock.mock_calls), 3) @patch("event.EventBroker.publish") async def test_try_pointing_pointable_closed_general_click_flag(self, mock: AsyncMock): @@ -308,7 +351,7 @@ async def test_try_pointing_pointable_closed_special_click(self, mock: AsyncMock await BoardEventHandler.receive_try_pointing(message) - # pointing-result, tile-state-changed 발행하는지 확인 + # pointing-result, flag-set 발행하는지 확인 self.assertEqual(len(mock.mock_calls), 2) # pointing-result @@ -324,13 +367,15 @@ async def test_try_pointing_pointable_closed_special_click(self, mock: AsyncMock self.assertTrue(got.payload.pointable) self.assertEqual(got.payload.pointer, pointer) - # tile-state-changed - got: Message[PointingResultPayload] = mock.mock_calls[1].args[0] + # flag-set + got: Message[FlagSetPayload] = mock.mock_calls[1].args[0] self.assertEqual(type(got), Message) - self.assertEqual(got.event, InteractionEvent.TILE_STATE_CHANGED) + self.assertEqual(got.event, InteractionEvent.FLAG_SET) # payload 확인 - self.assertEqual(type(got.payload), TileStateChangedPayload) + self.assertEqual(type(got.payload), FlagSetPayload) self.assertEqual(got.payload.position, pointer) + self.assertEqual(got.payload.color, color) + self.assertTrue(got.payload.is_set) expected_tile = Tile.create( is_open=False, @@ -343,7 +388,6 @@ async def test_try_pointing_pointable_closed_special_click(self, mock: AsyncMock fetched_tile = Tile.from_int(BoardHandler.fetch(start=pointer, end=pointer).data[0]) self.assertEqual(fetched_tile, expected_tile) - self.assertEqual(got.payload.tile, expected_tile) @patch("event.EventBroker.publish") async def test_try_pointing_pointable_closed_special_click_already_flag(self, mock: AsyncMock): @@ -364,7 +408,7 @@ async def test_try_pointing_pointable_closed_special_click_already_flag(self, mo await BoardEventHandler.receive_try_pointing(message) - # pointing-result, tile-state-changed 발행하는지 확인 + # pointing-result, flag-set 발행하는지 확인 self.assertEqual(len(mock.mock_calls), 2) # pointing-result @@ -380,13 +424,15 @@ async def test_try_pointing_pointable_closed_special_click_already_flag(self, mo self.assertTrue(got.payload.pointable) self.assertEqual(got.payload.pointer, pointer) - # tile-state-changed - got: Message[PointingResultPayload] = mock.mock_calls[1].args[0] + # flag-set + got: Message[FlagSetPayload] = mock.mock_calls[1].args[0] self.assertEqual(type(got), Message) - self.assertEqual(got.event, InteractionEvent.TILE_STATE_CHANGED) + self.assertEqual(got.event, InteractionEvent.FLAG_SET) # payload 확인 - self.assertEqual(type(got.payload), TileStateChangedPayload) + self.assertEqual(type(got.payload), FlagSetPayload) self.assertEqual(got.payload.position, pointer) + self.assertIsNone(got.payload.color) + self.assertFalse(got.payload.is_set) expected_tile = Tile.create( is_open=False, @@ -399,7 +445,6 @@ async def test_try_pointing_pointable_closed_special_click_already_flag(self, mo fetched_tile = Tile.from_int(BoardHandler.fetch(start=pointer, end=pointer).data[0]) self.assertEqual(fetched_tile, expected_tile) - self.assertEqual(got.payload.tile, expected_tile) @patch("event.EventBroker.publish") async def test_try_pointing_not_pointable(self, mock: AsyncMock): diff --git a/conn/internal/conn.py b/conn/internal/conn.py index 2d2ebf3..99fa28f 100644 --- a/conn/internal/conn.py +++ b/conn/internal/conn.py @@ -1,4 +1,5 @@ from fastapi.websockets import WebSocket +from websockets.exceptions import ConnectionClosed from message import Message from dataclasses import dataclass @@ -22,4 +23,8 @@ async def receive(self): return Message.from_str(await self.conn.receive_text()) async def send(self, msg: Message): - await self.conn.send_text(msg.to_str()) + try: + await self.conn.send_text(msg.to_str()) + except ConnectionClosed: + # 커넥션이 종료되었는데도 타이밍 문제로 인해 커넥션을 가져왔을 수 있음. + return diff --git a/conn/manager/internal/connection_manager.py b/conn/manager/internal/connection_manager.py index e630e13..6130b9b 100644 --- a/conn/manager/internal/connection_manager.py +++ b/conn/manager/internal/connection_manager.py @@ -1,3 +1,4 @@ +import asyncio from fastapi.websockets import WebSocket from conn import Conn from message import Message @@ -65,9 +66,14 @@ def generate_conn_id(): @staticmethod async def receive_broadcast_event(message: Message): overwrite_event(message) + + coroutines = [] + for id in ConnectionManager.conns: conn = ConnectionManager.conns[id] - await conn.send(message) + coroutines.append(conn.send(message)) + + await asyncio.gather(*coroutines) @EventBroker.add_receiver("multicast") @staticmethod @@ -75,12 +81,17 @@ async def receive_multicast_event(message: Message): overwrite_event(message) if "target_conns" not in message.header: raise DumbHumanException() + + coroutines = [] + for conn_id in message.header["target_conns"]: conn = ConnectionManager.get_conn(conn_id) if not conn: raise DumbHumanException() - await conn.send(message) + coroutines.append(conn.send(message)) + + await asyncio.gather(*coroutines) @staticmethod async def handle_message(message: Message): diff --git a/cursor/data/handler/internal/cursor_handler.py b/cursor/data/handler/internal/cursor_handler.py index 888173a..3036678 100644 --- a/cursor/data/handler/internal/cursor_handler.py +++ b/cursor/data/handler/internal/cursor_handler.py @@ -15,8 +15,10 @@ class CursorHandler: watching: dict[str, list[str]] = {} @staticmethod - def create_cursor(conn_id: str): + def create_cursor(conn_id: str, position: Point, width: int, height: int): cursor = Cursor.create(conn_id) + cursor.position = position + cursor.set_size(width=width, height=height) CursorHandler.cursor_dict[conn_id] = cursor @@ -34,41 +36,79 @@ def get_cursor(conn_id: str) -> Cursor | None: # range 안에 커서가 있는가 @staticmethod - def exists_range(start: Point, end: Point, *exclude_ids) -> list[Cursor]: + def exists_range( + start: Point, end: Point, exclude_ids: list[str] = [], + exclude_start: Point | None = None, exclude_end: Point | None = None + ) -> list[Cursor]: result = [] - for key in CursorHandler.cursor_dict: - if exclude_ids and key in exclude_ids: + for cursor_id in CursorHandler.cursor_dict: + if cursor_id in exclude_ids: continue - cur = CursorHandler.cursor_dict[key] - if start.x > cur.position.x: - continue - if end.x < cur.position.x: - continue - if start.y < cur.position.y: - continue - if end.y > cur.position.y: + + cursor = CursorHandler.cursor_dict[cursor_id] + pos = cursor.position + # start & end 범위를 벗어나는가 + if \ + start.x > pos.x or end.x < pos.x or \ + end.y > pos.y or start.y < pos.y: continue - result.append(cur) + + # exclude_range 범위에 들어가는가 + if exclude_start is not None and exclude_end is not None: + if \ + pos.x >= exclude_start.x and pos.x <= exclude_end.x and \ + pos.y >= exclude_end.y and pos.y <= exclude_start.y: + continue + + result.append(cursor) return result # 커서 view에 tile이 포함되는가 @staticmethod - def view_includes(p: Point, *exclude_ids) -> list[Cursor]: + def view_includes_point(p: Point, exclude_ids: list[str] = []) -> list[Cursor]: result = [] - for key in CursorHandler.cursor_dict: - if exclude_ids and key in exclude_ids: + for cursor_id in CursorHandler.cursor_dict: + if cursor_id in exclude_ids: continue - cur = CursorHandler.cursor_dict[key] - if (cur.position.x - cur.width) > p.x: + + cursor = CursorHandler.cursor_dict[cursor_id] + + # 커서 뷰 범위를 벗어나는가 + if not cursor.check_in_view(p): continue - if (cur.position.x + cur.width) < p.x: + + result.append(cursor) + + return result + + # 커서 view에 range가 포함되는가 + @staticmethod + def view_includes_range(start: Point, end: Point, exclude_ids: list[str] = []) -> list[Cursor]: + result = [] + for cursor_id in CursorHandler.cursor_dict: + if cursor_id in exclude_ids: continue - if (cur.position.y - cur.height) > p.y: + + cursor = CursorHandler.cursor_dict[cursor_id] + + left_top = Point( + x=cursor.position.x - cursor.width, + y=cursor.position.y + cursor.height + ) + right_bottom = Point( + x=cursor.position.x + cursor.width, + y=cursor.position.y - cursor.height + ) + + # left_top이 end보다 오른쪽 혹은 아래인가 + if left_top.x > end.x or left_top.y < end.y: continue - if (cur.position.y + cur.height) < p.y: + # right_bottom이 start 보다 왼쪽 혹은 위인가 + if right_bottom.x < start.x or right_bottom.y > start.y: continue - result.append(cur) + + result.append(cursor) return result diff --git a/cursor/data/handler/test/cursor_handler_test.py b/cursor/data/handler/test/cursor_handler_test.py index 41eb961..8b4b9f6 100644 --- a/cursor/data/handler/test/cursor_handler_test.py +++ b/cursor/data/handler/test/cursor_handler_test.py @@ -50,11 +50,17 @@ def tearDown(self): def test_create(self): conn_id = "example_conn_id" - _ = CursorHandler.create_cursor(conn_id) + width, height = 10, 10 + position = Point(1, 1) + + _ = CursorHandler.create_cursor(conn_id, position, width, height) self.assertIn(conn_id, CursorHandler.cursor_dict) self.assertEqual(type(CursorHandler.cursor_dict[conn_id]), Cursor) self.assertEqual(CursorHandler.cursor_dict[conn_id].conn_id, conn_id) + self.assertEqual(CursorHandler.cursor_dict[conn_id].width, width) + self.assertEqual(CursorHandler.cursor_dict[conn_id].height, height) + self.assertEqual(CursorHandler.cursor_dict[conn_id].position, position) def test_get_cursor(self): a_cur: Cursor | None = CursorHandler.get_cursor("A") @@ -78,20 +84,72 @@ def test_remove_cursor(self): self.assertEqual(len(CursorHandler.cursor_dict), 2) def test_exists_range(self): - result = CursorHandler.exists_range(Point(-3, 3), Point(3, -3)) - result.sort(key=lambda c: c.conn_id) + result = CursorHandler.exists_range(start=Point(-3, 3), end=Point(3, -3)) + + result = [c.conn_id for c in result] + + self.assertEqual(len(result), 2) + self.assertIn("A", result) + self.assertIn("C", result) + + def test_exists_range_exclude_id(self): + result = CursorHandler.exists_range(start=Point(-3, 3), end=Point(3, -3), exclude_ids=["A"]) + + result = [c.conn_id for c in result] + + self.assertEqual(len(result), 1) + self.assertIn("C", result) + + def test_exists_range_exclude_range(self): + result = CursorHandler.exists_range( + start=Point(-3, 3), end=Point(3, -3), + exclude_start=Point(-4, 3), exclude_end=Point(0, 1) + ) + + result = [c.conn_id for c in result] + + self.assertEqual(len(result), 1) + self.assertIn("C", result) + + def test_view_includes_point(self): + result = CursorHandler.view_includes_point(p=Point(-3, 0)) + + result = [c.conn_id for c in result] self.assertEqual(len(result), 2) - self.assertEqual(result[0].conn_id, "A") - self.assertEqual(result[1].conn_id, "C") + self.assertIn("A", result) + self.assertIn("B", result) + + def test_view_includes_point_exclude_id(self): + result = CursorHandler.view_includes_point(p=Point(-3, 0), exclude_ids=["A"]) + + result = [c.conn_id for c in result] + + self.assertEqual(len(result), 1) + self.assertIn("B", result) + + def test_view_includes_range(self): + start = Point(-3, 1) + end = Point(-2, 0) + result = CursorHandler.view_includes_range(start=start, end=end) + + result = [c.conn_id for c in result] + + self.assertEqual(len(result), 3) + self.assertIn("A", result) + self.assertIn("B", result) + self.assertIn("C", result) + + def test_view_includes_range__exclude_id(self): + start = Point(-3, 1) + end = Point(-2, 0) + result = CursorHandler.view_includes_range(start=start, end=end, exclude_ids=["A"]) - def test_view_includes(self): - result = CursorHandler.view_includes(Point(-3, 0)) - result.sort(key=lambda c: c.conn_id) + result = [c.conn_id for c in result] self.assertEqual(len(result), 2) - self.assertEqual(result[0].conn_id, "A") - self.assertEqual(result[1].conn_id, "B") + self.assertIn("B", result) + self.assertIn("C", result) def test_add_watcher(self): CursorHandler.add_watcher( diff --git a/cursor/event/handler/internal/cursor_event_handler.py b/cursor/event/handler/internal/cursor_event_handler.py index 1d87672..5013857 100644 --- a/cursor/event/handler/internal/cursor_event_handler.py +++ b/cursor/event/handler/internal/cursor_event_handler.py @@ -1,6 +1,7 @@ +import asyncio from cursor.data import Cursor from cursor.data.handler import CursorHandler -from board.data import Point, Tile +from board.data import Point, Tile, Tiles from event import EventBroker from message import Message from datetime import datetime, timedelta @@ -21,23 +22,30 @@ MovableResultPayload, MovedPayload, InteractionEvent, - TileStateChangedPayload, - TileUpdatedPayload, + FlagSetPayload, + SingleTileOpenedPayload, + TilesOpenedPayload, YouDiedPayload, ConnClosedPayload, CursorQuitPayload, SetViewSizePayload, ErrorEvent, - ErrorPayload + ErrorPayload, + NewCursorCandidatePayload ) class CursorEventHandler: - @EventBroker.add_receiver(NewConnEvent.NEW_CONN) + @EventBroker.add_receiver(NewConnEvent.NEW_CURSOR_CANDIDATE) @staticmethod - async def receive_new_conn(message: Message[NewConnPayload]): - cursor = CursorHandler.create_cursor(message.payload.conn_id) - cursor.set_size(message.payload.width, message.payload.height) + async def receive_new_cursor_candidate(message: Message[NewCursorCandidatePayload]): + cursor = CursorHandler.create_cursor( + conn_id=message.payload.conn_id, + position=message.payload.position, + width=message.payload.width, height=message.payload.height + ) + + publish_coroutines = [] new_cursor_message = Message( event="multicast", @@ -50,7 +58,7 @@ async def receive_new_conn(message: Message[NewConnPayload]): ) ) - await EventBroker.publish(new_cursor_message) + publish_coroutines.append(EventBroker.publish(new_cursor_message)) start_p = Point( x=cursor.position.x - cursor.width, @@ -61,28 +69,34 @@ async def receive_new_conn(message: Message[NewConnPayload]): y=cursor.position.y - cursor.height ) - cursors_in_range = CursorHandler.exists_range(start_p, end_p, cursor.conn_id) + cursors_in_range = CursorHandler.exists_range(start=start_p, end=end_p, exclude_ids=[cursor.conn_id]) if len(cursors_in_range) > 0: # 내가 보고있는 커서들 for other_cursor in cursors_in_range: CursorHandler.add_watcher(watcher=cursor, watching=other_cursor) - await publish_new_cursors_event( - target_cursors=[cursor], - cursors=cursors_in_range + publish_coroutines.append( + publish_new_cursors_event( + target_cursors=[cursor], + cursors=cursors_in_range + ) ) - cursors_with_view_including = CursorHandler.view_includes(cursor.position, cursor.conn_id) + cursors_with_view_including = CursorHandler.view_includes_point(p=cursor.position, exclude_ids=[cursor.conn_id]) if len(cursors_with_view_including) > 0: # 나를 보고있는 커서들 for other_cursor in cursors_with_view_including: CursorHandler.add_watcher(watcher=other_cursor, watching=cursor) - await publish_new_cursors_event( - target_cursors=cursors_with_view_including, - cursors=[cursor] + publish_coroutines.append( + publish_new_cursors_event( + target_cursors=cursors_with_view_including, + cursors=[cursor] + ) ) + await asyncio.gather(*publish_coroutines) + @EventBroker.add_receiver(PointEvent.POINTING) @staticmethod async def receive_pointing(message: Message[PointingPayload]): @@ -219,12 +233,11 @@ async def receive_movable_result(message: Message[MovableResultPayload]): cursor.position = new_position # TODO: 새로운 방식으로 커서들 찾기. 최적화하기. - # set을 사용하면 제약이 있음. # 새로운 뷰의 커서들 찾기 top_left = Point(cursor.position.x - cursor.width, cursor.position.y + cursor.height) bottom_right = Point(cursor.position.x + cursor.width, cursor.position.y - cursor.height) - cursors_in_view = CursorHandler.exists_range(top_left, bottom_right, cursor.conn_id) + cursors_in_view = CursorHandler.exists_range(start=top_left, end=bottom_right, exclude_ids=[cursor.conn_id]) original_watching_ids = CursorHandler.get_watching(cursor_id=cursor.conn_id) original_watchings = [CursorHandler.get_cursor(id) for id in original_watching_ids] @@ -236,6 +249,8 @@ async def receive_movable_result(message: Message[MovableResultPayload]): if not in_view: CursorHandler.remove_watcher(watcher=cursor, watching=other_cursor) + publish_coroutines = [] + new_watchings = list(filter(lambda c: c.conn_id not in original_watching_ids, cursors_in_view)) if len(new_watchings) > 0: # 새로운 watching 커서들 연관관계 설정 @@ -243,13 +258,15 @@ async def receive_movable_result(message: Message[MovableResultPayload]): CursorHandler.add_watcher(watcher=cursor, watching=other_cursor) # 새로운 커서들 전달 - await publish_new_cursors_event( - target_cursors=[cursor], - cursors=new_watchings + publish_coroutines.append( + publish_new_cursors_event( + target_cursors=[cursor], + cursors=new_watchings + ) ) # 새로운 위치를 바라보고 있는 커서들 찾기, 본인 제외 - watchers_new_pos = CursorHandler.view_includes(new_position, cursor.conn_id) + watchers_new_pos = CursorHandler.view_includes_point(p=new_position, exclude_ids=[cursor.conn_id]) original_watcher_ids = CursorHandler.get_watchers(cursor_id=cursor.conn_id) original_watchers = [CursorHandler.get_cursor(id) for id in original_watcher_ids] @@ -268,7 +285,8 @@ async def receive_movable_result(message: Message[MovableResultPayload]): color=cursor.color, ) ) - await EventBroker.publish(message) + + publish_coroutines.append(EventBroker.publish(message)) # 범위 벗어나면 watcher 제거 for watcher in original_watchers: @@ -283,45 +301,50 @@ async def receive_movable_result(message: Message[MovableResultPayload]): CursorHandler.add_watcher(watcher=other_cursor, watching=cursor) # 새로운 커서들에게 본인 커서 전달 - await publish_new_cursors_event( - target_cursors=new_watchers, - cursors=[cursor] + publish_coroutines.append( + publish_new_cursors_event( + target_cursors=new_watchers, + cursors=[cursor] + ) ) - @EventBroker.add_receiver(InteractionEvent.TILE_STATE_CHANGED) + await asyncio.gather(*publish_coroutines) + + @EventBroker.add_receiver(InteractionEvent.SINGLE_TILE_OPENED) @staticmethod - async def receive_tile_state_changed(message: Message[TileStateChangedPayload]): + async def receive_single_tile_opened(message: Message[SingleTileOpenedPayload]): position = message.payload.position - tile = message.payload.tile + tile_str = message.payload.tile + + tiles = Tiles(data=bytearray.fromhex(tile_str)) + tile = Tile.from_int(tiles.data[0]) - pub_tile = tile - if not tile.is_open: - # 닫힌 타일의 mine, number 정보는 버리기 - pub_tile = tile.copy(hide_info=True) + publish_coroutines = [] # 변경된 타일을 보고있는 커서들에게 전달 - view_cursors = CursorHandler.view_includes(position) + view_cursors = CursorHandler.view_includes_point(p=position) if len(view_cursors) > 0: pub_message = Message( event="multicast", - header={"target_conns": [c.conn_id for c in view_cursors], - "origin_event": InteractionEvent.TILE_UPDATED}, - payload=TileUpdatedPayload( - position=position, - tile=pub_tile - ) + header={ + "target_conns": [c.conn_id for c in view_cursors], + "origin_event": message.event + }, + payload=message.payload ) - await EventBroker.publish(pub_message) + publish_coroutines.append(EventBroker.publish(pub_message)) - if not (tile.is_open and tile.is_mine): + if not tile.is_mine: + await asyncio.gather(*publish_coroutines) return # 주변 8칸 커서들 죽이기 start_p = Point(position.x - 1, position.y + 1) end_p = Point(position.x + 1, position.y - 1) - nearby_cursors = CursorHandler.exists_range(start_p, end_p) + nearby_cursors = CursorHandler.exists_range(start=start_p, end=end_p) if len(nearby_cursors) > 0: + # TODO: 하드코딩 없애기 revive_at = datetime.now() + timedelta(minutes=3) for c in nearby_cursors: @@ -329,11 +352,49 @@ async def receive_tile_state_changed(message: Message[TileStateChangedPayload]): pub_message = Message( event="multicast", - header={"target_conns": [c.conn_id for c in nearby_cursors], - "origin_event": InteractionEvent.YOU_DIED}, - payload=YouDiedPayload( - revive_at=revive_at.astimezone().isoformat() - ) + header={ + "target_conns": [c.conn_id for c in nearby_cursors], + "origin_event": InteractionEvent.YOU_DIED + }, + payload=YouDiedPayload(revive_at=revive_at.astimezone().isoformat()) + ) + publish_coroutines.append(EventBroker.publish(pub_message)) + + await asyncio.gather(*publish_coroutines) + + @staticmethod + async def receive_tiles_opened(message: Message[TilesOpenedPayload]): + start_p = message.payload.start_p + end_p = message.payload.end_p + + # 변경된 타일을 보고있는 커서들에게 전달 + view_cursors = CursorHandler.view_includes_range(start=start_p, end=end_p) + if len(view_cursors) > 0: + pub_message = Message( + event="multicast", + header={ + "target_conns": [c.conn_id for c in view_cursors], + "origin_event": message.event + }, + payload=message.payload + ) + await EventBroker.publish(pub_message) + + @EventBroker.add_receiver(InteractionEvent.FLAG_SET) + @staticmethod + async def receive_flag_set(message: Message[FlagSetPayload]): + position = message.payload.position + + # 변경된 타일을 보고있는 커서들에게 전달 + view_cursors = CursorHandler.view_includes_point(p=position) + if len(view_cursors) > 0: + pub_message = Message( + event="multicast", + header={ + "target_conns": [c.conn_id for c in view_cursors], + "origin_event": message.event + }, + payload=message.payload ) await EventBroker.publish(pub_message) @@ -383,20 +444,33 @@ async def receive_set_view_size(message: Message[SetViewSizePayload]): cur_watching = CursorHandler.get_watching(cursor_id=cursor.conn_id) - size_grown = new_width > cursor.width or new_height > cursor.height + old_width, old_height = cursor.width, cursor.height cursor.set_size(new_width, new_height) + size_grown = (new_width > old_width) or (new_height > old_height) + if size_grown: - top_left = Point(x=cursor.position.x - cursor.width, y=cursor.position.y + cursor.height) - bottom_right = Point(x=cursor.position.x + cursor.width, y=cursor.position.y - cursor.height) + pos = cursor.position - exclude_list = [cursor.conn_id] + cur_watching - new_watchings = CursorHandler.exists_range(top_left, bottom_right, *exclude_list) + # 현재 범위 + old_top_left = Point(x=pos.x - old_width, y=pos.y + old_height) + old_bottom_right = Point(x=pos.x + old_width, y=pos.y - old_height) - for other_cursor in new_watchings: - CursorHandler.add_watcher(watcher=cursor, watching=other_cursor) + # 새로운 범위 + new_top_left = Point(x=pos.x - new_width, y=pos.y + new_height) + new_bottom_right = Point(x=pos.x + new_width, y=pos.y - new_height) + + # 현재 범위를 제외한 새로운 범위에서 커서들 가져오기 + new_watchings = CursorHandler.exists_range( + start=new_top_left, end=new_bottom_right, + exclude_start=old_top_left, exclude_end=old_bottom_right + ) + + if len(new_watchings) > 0: + for other_cursor in new_watchings: + CursorHandler.add_watcher(watcher=cursor, watching=other_cursor) - await publish_new_cursors_event(target_cursors=[cursor], cursors=new_watchings) + await publish_new_cursors_event(target_cursors=[cursor], cursors=new_watchings) for id in cur_watching: other_cursor = CursorHandler.get_cursor(id) diff --git a/cursor/event/handler/test/__init__.py b/cursor/event/handler/test/__init__.py index 45401da..09512b0 100644 --- a/cursor/event/handler/test/__init__.py +++ b/cursor/event/handler/test/__init__.py @@ -1,5 +1,5 @@ from .cursor_event_handler_test import ( - CursorEventHandler_NewConnReceiver_TestCase, + CursorEventHandler_NewCursorCandidateReceiver_TestCase, CursorEventHandler_PointingReceiver_TestCase, CursorEventHandler_MovingReceiver_TestCase, CursorEventHandler_TileStateChanged_TestCase, diff --git a/cursor/event/handler/test/cursor_event_handler_test.py b/cursor/event/handler/test/cursor_event_handler_test.py index e941cc1..fe49b78 100644 --- a/cursor/event/handler/test/cursor_event_handler_test.py +++ b/cursor/event/handler/test/cursor_event_handler_test.py @@ -1,10 +1,10 @@ +import asyncio from cursor.data import Cursor, Color from cursor.data.handler import CursorHandler from cursor.event.handler import CursorEventHandler from message import Message from message.payload import ( NewConnEvent, - NewConnPayload, MyCursorPayload, CursorsPayload, PointEvent, @@ -19,19 +19,21 @@ MovableResultPayload, MovedPayload, InteractionEvent, - TileStateChangedPayload, YouDiedPayload, - TileUpdatedPayload, + SingleTileOpenedPayload, + TilesOpenedPayload, + FlagSetPayload, ConnClosedPayload, CursorQuitPayload, SetViewSizePayload, ErrorEvent, - ErrorPayload + ErrorPayload, + NewCursorCandidatePayload ) from .fixtures import setup_cursor_locations import unittest from unittest.mock import AsyncMock, patch -from board.data import Point, Tile +from board.data import Point, Tile, Tiles """ CursorEventHandler Test @@ -40,7 +42,7 @@ ✅ : test 통과 ❌ : test 실패 🖊️ : test 작성 -- new-conn-receiver +- new-cursor-receiver - ✅| normal-case - ✅| without-cursors - 작성해야함 @@ -55,16 +57,16 @@ """ -class CursorEventHandler_NewConnReceiver_TestCase(unittest.IsolatedAsyncioTestCase): +class CursorEventHandler_NewCursorCandidateReceiver_TestCase(unittest.IsolatedAsyncioTestCase): def tearDown(self): CursorHandler.cursor_dict = {} CursorHandler.watchers = {} CursorHandler.watching = {} @patch("event.EventBroker.publish") - async def test_new_conn_receive_without_cursors(self, mock: AsyncMock): + async def test_new_cursor_candidate_receive_without_cursors(self, mock: AsyncMock): """ - new-conn-receiver + new-cursor-candidate-receiver without-cursors description: @@ -72,7 +74,7 @@ async def test_new_conn_receive_without_cursors(self, mock: AsyncMock): ---------------------------- trigger event -> - - new-conn : message[NewConnPayload] + - new-cursor-candidate : message[NewConnPayload] - header : - sender : conn_id - descrption : @@ -109,19 +111,21 @@ async def test_new_conn_receive_without_cursors(self, mock: AsyncMock): expected_conn_id = "example" expected_height = 100 expected_width = 100 + position = Point(1, 1) # trigger message 생성 message = Message( - event=NewConnEvent.NEW_CONN, - payload=NewConnPayload( + event=NewConnEvent.NEW_CURSOR_CANDIDATE, + payload=NewCursorCandidatePayload( conn_id=expected_conn_id, width=expected_width, - height=expected_height + height=expected_height, + position=position ) ) # trigger event - await CursorEventHandler.receive_new_conn(message) + await CursorEventHandler.receive_new_cursor_candidate(message) # 호출 여부 self.assertEqual(len(mock.mock_calls), 1) @@ -141,12 +145,40 @@ async def test_new_conn_receive_without_cursors(self, mock: AsyncMock): # message.payload self.assertEqual(type(got.payload), MyCursorPayload) self.assertIsNone(got.payload.pointer) - self.assertEqual(got.payload.position.x, 0) - self.assertEqual(got.payload.position.y, 0) + self.assertEqual(got.payload.position, position) self.assertIn(got.payload.color, Color) @patch("event.EventBroker.publish") - async def test_receive_new_conn_with_cursors(self, mock: AsyncMock): + async def test_new_cursor_candidate_receive_without_cursors_race(self, mock: AsyncMock): + conn_1 = "1" + conn_2 = "2" + height = 1 + width = 1 + position = Point(0, 0) + + new_cursor_1_msg = Message( + event=NewConnEvent.NEW_CURSOR_CANDIDATE, + payload=NewCursorCandidatePayload(conn_id=conn_1, width=width, height=height, position=position) + ) + new_cursor_2_msg = Message( + event=NewConnEvent.NEW_CURSOR_CANDIDATE, + payload=NewCursorCandidatePayload(conn_id=conn_2, width=width, height=height, position=position) + ) + + # 코루틴 스위칭을 위해 sleep. 이게 되는 이유를 모르겠다. + async def sleep(_): + await asyncio.sleep(0) + mock.side_effect = sleep + + await asyncio.gather( + CursorEventHandler.receive_new_cursor_candidate(new_cursor_1_msg), + CursorEventHandler.receive_new_cursor_candidate(new_cursor_2_msg) + ) + # 첫번째 conn: my-cursor, 두번째 conn: my-cursor, cursors * 2 + self.assertEqual(len(mock.mock_calls), 4) + + @patch("event.EventBroker.publish") + async def test_receive_new_cursor_candidate_with_cursors(self, mock: AsyncMock): # /docs/example/cursor-location.png # But B is at 0,0 CursorHandler.cursor_dict = { @@ -178,17 +210,19 @@ async def test_receive_new_conn_with_cursors(self, mock: AsyncMock): new_conn_id = "B" height = 7 width = 7 + position = Point(0, 0) message = Message( - event=NewConnEvent.NEW_CONN, - payload=NewConnPayload( + event=NewConnEvent.NEW_CURSOR_CANDIDATE, + payload=NewCursorCandidatePayload( conn_id=new_conn_id, width=height, - height=width + height=width, + position=position ) ) - await CursorEventHandler.receive_new_conn(message) + await CursorEventHandler.receive_new_cursor_candidate(message) # publish 횟수 self.assertEqual(len(mock.mock_calls), 3) @@ -206,7 +240,7 @@ async def test_receive_new_conn_with_cursors(self, mock: AsyncMock): self.assertEqual(got.header["origin_event"], NewConnEvent.MY_CURSOR) # payload 확인 self.assertEqual(type(got.payload), MyCursorPayload) - self.assertEqual(got.payload.position, Point(0, 0)) + self.assertEqual(got.payload.position, position) self.assertIsNone(got.payload.pointer) self.assertIn(got.payload.color, Color) @@ -767,7 +801,7 @@ async def test_receive_movable_result_c_left(self, mock: AsyncMock): self.assertEqual(len(mock.mock_calls), 2) # cursors - got = mock.mock_calls[0].args[0] + got = mock.mock_calls[1].args[0] self.assertEqual(type(got), Message) self.assertEqual(got.event, "multicast") # origin_event @@ -788,7 +822,7 @@ async def test_receive_movable_result_c_left(self, mock: AsyncMock): self.assertEqual(got.payload.cursors[1].color, self.cur_b.color) # moved - got = mock.mock_calls[1].args[0] + got = mock.mock_calls[0].args[0] self.assertEqual(type(got), Message) self.assertEqual(got.event, "multicast") # origin_event @@ -825,74 +859,78 @@ def tearDown(self): CursorHandler.watching = {} @patch("event.EventBroker.publish") - async def test_receive_tile_state_changed(self, mock: AsyncMock): + async def test_receive_flag_set(self, mock: AsyncMock): position = Point(-4, -3) - tile = Tile.from_int(0b00100111) # not open, flag, 7 + color = Color.BLUE + is_set = True - message: Message[TileStateChangedPayload] = Message( - event=InteractionEvent.TILE_STATE_CHANGED, - payload=TileStateChangedPayload( + message: Message[FlagSetPayload] = Message( + event=InteractionEvent.FLAG_SET, + payload=FlagSetPayload( position=position, - tile=tile + color=color, + is_set=is_set ) ) - await CursorEventHandler.receive_tile_state_changed(message) + await CursorEventHandler.receive_flag_set(message) - # tile-updated 발행 확인 + # flag-set 발행 확인 self.assertEqual(len(mock.mock_calls), 1) - # tile-updated - got: Message[TileUpdatedPayload] = mock.mock_calls[0].args[0] + # flag-set + got: Message[FlagSetPayload] = mock.mock_calls[0].args[0] self.assertEqual(type(got), Message) self.assertEqual(got.event, "multicast") # origin_event self.assertIn("origin_event", got.header) - self.assertEqual(got.header["origin_event"], InteractionEvent.TILE_UPDATED) + self.assertEqual(got.header["origin_event"], InteractionEvent.FLAG_SET) # target_conns 확인, [A, B] self.assertIn("target_conns", got.header) self.assertEqual(len(got.header["target_conns"]), 2) self.assertIn("A", got.header["target_conns"]) self.assertIn("B", got.header["target_conns"]) # payload 확인 - self.assertEqual(type(got.payload), TileUpdatedPayload) + self.assertEqual(type(got.payload), FlagSetPayload) self.assertEqual(got.payload.position, position) - self.assertEqual(got.payload.tile, tile.copy(hide_info=True)) + self.assertEqual(got.payload.color, color) + self.assertEqual(got.payload.is_set, is_set) @patch("event.EventBroker.publish") - async def test_receive_tile_state_changed_mine_boom(self, mock: AsyncMock): + async def test_receive_single_tile_open(self, mock: AsyncMock): position = Point(-4, -3) tile = Tile.from_int(0b11000000) # open, mine + tile_str = Tiles(data=bytearray([tile.data])).to_str() - message: Message[TileStateChangedPayload] = Message( - event=InteractionEvent.TILE_STATE_CHANGED, - payload=TileStateChangedPayload( + message: Message[SingleTileOpenedPayload] = Message( + event=InteractionEvent.SINGLE_TILE_OPENED, + payload=SingleTileOpenedPayload( position=position, - tile=tile + tile=tile_str ) ) - await CursorEventHandler.receive_tile_state_changed(message) + await CursorEventHandler.receive_single_tile_opened(message) - # tile-updated, you-died 발행 확인 + # single-tile-opened, you-died 발행 확인 self.assertEqual(len(mock.mock_calls), 2) - # tile-updated - got: Message[TileUpdatedPayload] = mock.mock_calls[0].args[0] + # single-tile-opened + got: Message[SingleTileOpenedPayload] = mock.mock_calls[0].args[0] self.assertEqual(type(got), Message) self.assertEqual(got.event, "multicast") # origin_event self.assertIn("origin_event", got.header) - self.assertEqual(got.header["origin_event"], InteractionEvent.TILE_UPDATED) + self.assertEqual(got.header["origin_event"], InteractionEvent.SINGLE_TILE_OPENED) # target_conns 확인, [A, B] self.assertIn("target_conns", got.header) self.assertEqual(len(got.header["target_conns"]), 2) self.assertIn("A", got.header["target_conns"]) self.assertIn("B", got.header["target_conns"]) # payload 확인 - self.assertEqual(type(got.payload), TileUpdatedPayload) + self.assertEqual(type(got.payload), SingleTileOpenedPayload) self.assertEqual(got.payload.position, position) - self.assertEqual(got.payload.tile.data, tile.data) + self.assertEqual(bytearray.fromhex(got.payload.tile)[0], tile.data) # you-died got: Message[YouDiedPayload] = mock.mock_calls[1].args[0] @@ -915,6 +953,45 @@ async def test_receive_tile_state_changed_mine_boom(self, mock: AsyncMock): # self.assertEqual(got.payload.revive_at, something) datetime.fromisoformat(got.payload.revive_at) + @patch("event.EventBroker.publish") + async def test_receive_tiles_opened(self, mock: AsyncMock): + start = Point(-3, 1) + end = Point(-2, 0) + tile_str = "1234123412341234" + + message: Message[TilesOpenedPayload] = Message( + event=InteractionEvent.TILES_OPENED, + payload=TilesOpenedPayload( + start_p=start, + end_p=end, + tiles=tile_str + ) + ) + + await CursorEventHandler.receive_tiles_opened(message) + + # tiles-opened 확인 + self.assertEqual(len(mock.mock_calls), 1) + + # tiles-opened + got: Message[TilesOpenedPayload] = mock.mock_calls[0].args[0] + self.assertEqual(type(got), Message) + self.assertEqual(got.event, "multicast") + # origin_event + self.assertIn("origin_event", got.header) + self.assertEqual(got.header["origin_event"], InteractionEvent.TILES_OPENED) + # target_conns 확인, [A, B, C] + self.assertIn("target_conns", got.header) + self.assertEqual(len(got.header["target_conns"]), 3) + self.assertIn("A", got.header["target_conns"]) + self.assertIn("B", got.header["target_conns"]) + self.assertIn("C", got.header["target_conns"]) + # payload 확인 + self.assertEqual(type(got.payload), TilesOpenedPayload) + self.assertEqual(got.payload.start_p, start) + self.assertEqual(got.payload.end_p, end) + self.assertEqual(got.payload.tiles, tile_str) + class CursorEventHandler_ConnClosed_TestCase(unittest.IsolatedAsyncioTestCase): def setUp(self): diff --git a/event/internal/event_broker.py b/event/internal/event_broker.py index 6301391..df78933 100644 --- a/event/internal/event_broker.py +++ b/event/internal/event_broker.py @@ -1,4 +1,5 @@ from __future__ import annotations +import asyncio from typing import Callable, Generic from message import Message from .exceptions import NoMatchingReceiverException @@ -78,10 +79,14 @@ async def publish(message: Message): if message.event not in EventBroker.event_dict: raise NoMatchingReceiverException(message.event) + coroutines = [] + receiver_ids = EventBroker.event_dict[message.event] for id in receiver_ids: receiver = Receiver.get_receiver(id) - await receiver(message) + coroutines.append(receiver(message)) + + await asyncio.gather(*coroutines) def _debug(message: Message): print(message.to_str(del_header=False)) diff --git a/message/payload/__init__.py b/message/payload/__init__.py index 6e9806d..02a15ae 100644 --- a/message/payload/__init__.py +++ b/message/payload/__init__.py @@ -1,9 +1,9 @@ from .internal.tiles_payload import FetchTilesPayload, TilesPayload, TilesEvent from .internal.base_payload import Payload from .internal.exceptions import InvalidFieldException, MissingFieldException, DumbHumanException -from .internal.new_conn_payload import NewConnPayload, NewConnEvent, CursorPayload, CursorsPayload, MyCursorPayload, ConnClosedPayload, CursorQuitPayload, SetViewSizePayload +from .internal.new_conn_payload import NewConnPayload, NewConnEvent, CursorPayload, CursorsPayload, MyCursorPayload, ConnClosedPayload, CursorQuitPayload, SetViewSizePayload, NewCursorCandidatePayload from .internal.parsable_payload import ParsablePayload from .internal.pointing_payload import PointerSetPayload, PointingResultPayload, PointingPayload, TryPointingPayload, PointEvent, ClickType from .internal.move_payload import MoveEvent, MovingPayload, MovedPayload, CheckMovablePayload, MovableResultPayload -from .internal.interaction_payload import TileStateChangedPayload, TileUpdatedPayload, YouDiedPayload, InteractionEvent +from .internal.interaction_payload import YouDiedPayload, InteractionEvent, SingleTileOpenedPayload, TilesOpenedPayload, FlagSetPayload from .internal.error_payload import ErrorEvent, ErrorPayload diff --git a/message/payload/internal/interaction_payload.py b/message/payload/internal/interaction_payload.py index 2568532..a45c000 100644 --- a/message/payload/internal/interaction_payload.py +++ b/message/payload/internal/interaction_payload.py @@ -1,14 +1,16 @@ from .base_payload import Payload from .parsable_payload import ParsablePayload from board.data import Point, Tile +from cursor.data import Color from dataclasses import dataclass from enum import Enum class InteractionEvent(str, Enum): YOU_DIED = "you-died" - TILE_UPDATED = "tile-updated" - TILE_STATE_CHANGED = "tile-state-changed" + SINGLE_TILE_OPENED = "single-tile-opened" + TILES_OPENED = "tiles-opened" + FLAG_SET = "flag-set" @dataclass @@ -17,12 +19,20 @@ class YouDiedPayload(Payload): @dataclass -class TileUpdatedPayload(Payload): +class SingleTileOpenedPayload(Payload): position: ParsablePayload[Point] - tile: ParsablePayload[Tile] + tile: str @dataclass -class TileStateChangedPayload(Payload): +class TilesOpenedPayload(Payload): + start_p: ParsablePayload[Point] + end_p: ParsablePayload[Point] + tiles: str + + +@dataclass +class FlagSetPayload(Payload): position: ParsablePayload[Point] - tile: ParsablePayload[Tile] + is_set: bool + color: Color | None diff --git a/message/payload/internal/new_conn_payload.py b/message/payload/internal/new_conn_payload.py index 5679512..8c12160 100644 --- a/message/payload/internal/new_conn_payload.py +++ b/message/payload/internal/new_conn_payload.py @@ -8,6 +8,7 @@ class NewConnEvent(str, Enum): NEW_CONN = "new-conn" + NEW_CURSOR_CANDIDATE = "new-cursor-candidate" CURSORS = "cursors" MY_CURSOR = "my-cursor" CONN_CLOSED = "conn-closed" @@ -22,6 +23,14 @@ class NewConnPayload(Payload): height: int +@dataclass +class NewCursorCandidatePayload(Payload): + conn_id: str + width: int + height: int + position: ParsablePayload[Point] + + @dataclass class CursorPayload(Payload): position: ParsablePayload[Point] diff --git a/server.py b/server.py index 6bc87f4..04d87ff 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,5 @@ from fastapi import FastAPI, WebSocket, Response, WebSocketDisconnect +from websockets.exceptions import ConnectionClosed from conn.manager import ConnectionManager from board.event.handler import BoardEventHandler from cursor.event.handler import CursorEventHandler @@ -32,7 +33,8 @@ async def session(ws: WebSocket): message = await conn.receive() message.header = {"sender": conn.id} await ConnectionManager.handle_message(message) - except WebSocketDisconnect as e: + except (WebSocketDisconnect, ConnectionClosed) as e: + # 연결 종료됨 break except Exception as e: msg = e