-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenviornement.py
311 lines (277 loc) · 13 KB
/
enviornement.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import numpy as np
import pygame
import gymnasium as gym
from gymnasium import spaces
import random
from typing import List
class MineSweeper(gym.Env):
    """Minesweeper as a Gymnasium environment.

    Board encoding (both ``chart`` and ``hiddenChart`` are indexed ``[row][col]``):
        * ``-1``   covered tile (only in ``chart``)
        * ``0..8`` uncovered tile showing the number of adjacent bombs
        * ``9``    bomb

    ``chart`` is what the agent observes; ``hiddenChart`` is the full solution
    and is generated on the first pick of every episode so that the picked
    tile and its neighbours are never bombs.

    Actions are flat tile indices: ``action = row * TILE_X_AMOUNT + col``.

    Rewards (see ``step``): number of tiles uncovered by the pick, ``-1`` for
    picking an already uncovered tile, ``-200`` for picking a bomb.
    """

    # "console" is handled by render(). "rgb_array" is advertised but render()
    # has no branch for it and returns None in that mode.
    metadata = {"render_modes": ["human", "rgb_array", "console"], "render_fps": 4}

    # NOTE(review): these two lines run while the class body is executed, so the
    # sprite file must be reachable from the working directory even when nothing
    # is rendered. Kept because external code may read ``MineSweeper.sprites``.
    sprites = pygame.image.load("sprites/minesweeper.png")
    sprites = pygame.transform.scale(sprites, (128, 128))

    # Edge length of one tile in pixels.
    _TILE_PX = 32

    # Board value -> name of the instance attribute holding its sprite.
    _TILE_SPRITES = {
        -2: "spriteEmpty2",
        -1: "spriteEmpty1",
        0: "sprite0",
        1: "sprite1",
        2: "sprite2",
        3: "sprite3",
        4: "sprite4",
        5: "sprite5",
        6: "sprite6",
        7: "sprite7",
        8: "sprite8",
        9: "spriteBomb",
    }

    # Source of randomness for bomb placement. Defaults to the global ``random``
    # module; reset(seed=...) swaps in a dedicated, seeded generator.
    _rng = random

    def __init__(self, render_mode=None, sizeX=20, sizeY=20, bombs=80):
        """Create a board of ``sizeX`` x ``sizeY`` tiles containing ``bombs`` bombs.

        Raises:
            ValueError: if ``bombs`` is negative, or if fewer than 9 tiles are
                bomb-free, in which case the 3x3 safe area around the first
                pick cannot be guaranteed.
        """
        super().__init__()
        self.render_mode = render_mode
        self.TILE_X_AMOUNT = sizeX
        self.TILE_Y_AMOUNT = sizeY
        self.BOMB_AMOUNT = bombs
        self.WINNING_SCORE = (self.TILE_X_AMOUNT * self.TILE_Y_AMOUNT) - self.BOMB_AMOUNT

        if self.BOMB_AMOUNT < 0:
            raise ValueError("Error: bomb amount must not be negative")
        if self.WINNING_SCORE < 9:
            # Raise instead of returning a half-initialised environment.
            raise ValueError("Error: Unable to guarantee first pick as safe")

        # Observation is the visible grid: -1 covered, 0..8 safe tiles, 9 bomb.
        self.observation_space = spaces.Box(
            low=-1,
            high=9,
            shape=(self.TILE_Y_AMOUNT, self.TILE_X_AMOUNT),
            dtype=np.int8,
        )
        # One action per tile. The set of useful actions shrinks during an
        # episode; action_masks() tells the policy which tiles are still covered.
        self.action_space = spaces.Discrete(self.TILE_Y_AMOUNT * self.TILE_X_AMOUNT)
        self.possible_actions = list(range(self.TILE_Y_AMOUNT * self.TILE_X_AMOUNT))
        self.invalid_actions = []

        self.score = 0
        self.screen = None
        self.clock = None
        # reset() allocates chart, hiddenChart and grid.
        self.reset()

    def _neighbors(self, x, y):
        """Yield the in-bounds ``(x, y)`` coordinates of the tiles around ``(x, y)``."""
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                if dx == 0 and dy == 0:
                    continue
                nx, ny = x + dx, y + dy
                if 0 <= nx < self.TILE_X_AMOUNT and 0 <= ny < self.TILE_Y_AMOUNT:
                    yield nx, ny

    def getSprite(self, spriteSheet, w, h, x, y):
        """Cut the ``w`` x ``h`` cell at grid position ``(x, y)`` out of ``spriteSheet``."""
        sprite = pygame.Surface((w, h)).convert_alpha()
        sprite.blit(spriteSheet, (0, 0), (w * x, h * y, w, h))
        return sprite

    def drawTile(self, x, y, id):
        """Blit the sprite for board value ``id`` at tile position ``(x, y)``.

        Values without a sprite are ignored.
        """
        sprite_name = self._TILE_SPRITES.get(id)
        if sprite_name is None:
            return
        position = (x * self._TILE_PX, y * self._TILE_PX)
        self.screen.blit(getattr(self, sprite_name), position)

    def automaticUncover(self, x, y):
        """Uncover ``(x, y)`` and flood-fill through connected zero tiles.

        A tile is uncovered when it is not a bomb and is still covered; when
        the uncovered tile shows ``0`` its neighbours are processed as well.

        Returns:
            The number of tiles uncovered by this call.
        """
        uncovered = 0
        # Explicit stack instead of recursion so large empty areas cannot hit
        # Python's recursion limit.
        pending = [(x, y)]
        while pending:
            cx, cy = pending.pop()
            if self.hiddenChart[cy][cx] == 9 or self.chart[cy][cx] >= 0:
                continue
            self.chart[cy][cx] = self.hiddenChart[cy][cx]
            uncovered += 1
            if self.hiddenChart[cy][cx] == 0:
                pending.extend(self._neighbors(cx, cy))
        return uncovered

    def _placeBombs(self, x, y):
        """Fill ``hiddenChart`` with bombs and neighbour counts, keeping the
        3x3 area around ``(x, y)`` bomb-free."""
        candidates = [
            (col, row)
            for row in range(self.TILE_Y_AMOUNT)
            for col in range(self.TILE_X_AMOUNT)
            if abs(row - y) > 1 or abs(col - x) > 1
        ]
        # Bottom-right tile first: keeps the index -> tile mapping (and thus
        # the layout produced by a given random sequence) stable.
        candidates.reverse()

        for _ in range(self.BOMB_AMOUNT):
            k = self._rng.randint(0, len(candidates) - 1)
            col, row = candidates.pop(k)
            self.hiddenChart[row][col] = 9

        # Every bomb increments the counter of each non-bomb neighbour.
        for row, col in np.argwhere(self.hiddenChart == 9):
            for nx, ny in self._neighbors(col, row):
                if self.hiddenChart[ny][nx] != 9:
                    self.hiddenChart[ny][nx] += 1

    def pickTile(self, x, y):
        """Uncover the tile at column ``x``, row ``y``.

        On the first pick of an episode the bombs are placed first.

        Returns:
            ``-1`` if the tile is a bomb, ``-2`` if it was already uncovered,
            otherwise the number of tiles uncovered by this pick.
        """
        if self.firstMove:
            self._placeBombs(x, y)

        if self.hiddenChart[y][x] == 9:
            return -1
        # An uncovered tile shows the same value as the hidden chart.
        if self.chart[y][x] == self.hiddenChart[y][x]:
            return -2

        if self.hiddenChart[y][x] == 0:
            # Tiles that are not next to a bomb open up their whole area.
            totalReward = self.automaticUncover(x, y)
        else:
            self.chart[y][x] = self.hiddenChart[y][x]
            totalReward = 1
        self.firstMove = False
        return totalReward

    def reset(self, seed=None, options=None):
        """Start a new episode with a fully covered board.

        Args:
            seed: when given, bomb placement for this and following episodes
                is driven by a generator seeded with it. Without a seed the
                current generator (initially the global ``random`` module) is
                kept.
            options: accepted for API compatibility; unused.

        Returns:
            ``(observation, info)``. ``info`` describes the episode that just
            ended: ``"state"`` is ``"Won"`` or ``"Lost"`` and ``"score"`` is
            ``"<score>/<winning score>"``.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._rng = random.Random(int(seed))

        # Built before the score is cleared so it reports the previous episode.
        info = {
            "state": "Won" if self.score == self.WINNING_SCORE else "Lost",
            "score": str(self.score) + "/" + str(self.WINNING_SCORE),
        }

        shape = (self.TILE_Y_AMOUNT, self.TILE_X_AMOUNT)
        self.chart = np.full(shape, -1.0)  # every tile starts covered
        self.hiddenChart = np.zeros(shape)
        self.grid = np.empty(shape, dtype=pygame.Rect)
        size = self._TILE_PX
        for row in range(self.TILE_Y_AMOUNT):
            for col in range(self.TILE_X_AMOUNT):
                self.grid[row][col] = pygame.Rect(col * size, row * size, size, size)

        self.invalid_actions = []
        self.score = 0
        self.firstMove = True
        return self.chart.astype(np.int8), info

    def decode_action_y(self, action):
        """Return the row encoded in a flat ``action`` index."""
        # Actions are row * TILE_X_AMOUNT + col, so the row is the quotient by
        # the row length (the number of columns).
        return int(action // self.TILE_X_AMOUNT)

    def decode_action_x(self, action):
        """Return the column encoded in a flat ``action`` index."""
        return action % self.TILE_X_AMOUNT

    def step(self, action):
        """Pick the tile encoded by ``action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The reward
            is the number of tiles uncovered, ``-1`` for an already uncovered
            tile or ``-200`` for a bomb. The episode terminates on a bomb or
            when every safe tile is uncovered. ``truncated`` is always
            ``False``. ``info`` reflects the state after the action.
        """
        uncoveredtiles = self.pickTile(
            self.decode_action_x(action), self.decode_action_y(action)
        )
        self.update_invalid_actions()

        state = "Playing"
        terminated = False
        if uncoveredtiles == -1:  # clicked tile with bomb
            self.revealChart()
            reward = -200
            terminated = True
            state = "Lost"
        elif uncoveredtiles == -2:  # clicked uncovered tile
            reward = -1
        else:  # clicked safe tile
            reward = uncoveredtiles
            self.score += uncoveredtiles
            if self.score == self.WINNING_SCORE:
                terminated = True
                state = "Won"

        info = {
            "state": state,
            "score": str(self.score) + "/" + str(self.WINNING_SCORE),
        }
        return self.chart.astype(np.int8), reward, terminated, False, info

    def render_frame_human(self):
        """Draw the visible board in a pygame window, creating it on first use."""
        if self.render_mode != "human":
            return
        if self.screen is None:
            pygame.init()
            size = self._TILE_PX
            self.screen = pygame.display.set_mode(
                (self.TILE_X_AMOUNT * size, self.TILE_Y_AMOUNT * size)
            )
            pygame.display.set_caption('Minesweeper')
            self.sprites = pygame.image.load("sprites/minesweeper.png")
            self.sprites = pygame.transform.scale(self.sprites, (128, 128))
            # (attribute name, column, row) of each sprite on the 4x4 sheet.
            layout = (
                ("spriteEmpty1", 0, 0),
                ("spriteEmpty2", 1, 0),
                ("sprite0", 3, 3),
                ("sprite1", 0, 1),
                ("sprite2", 1, 1),
                ("sprite3", 2, 1),
                ("sprite4", 3, 1),
                ("sprite5", 0, 2),
                ("sprite6", 1, 2),
                ("sprite7", 2, 2),
                ("sprite8", 3, 2),
                ("spriteBomb", 0, 3),
            )
            for attr, sheet_x, sheet_y in layout:
                setattr(self, attr, self.getSprite(self.sprites, size, size, sheet_x, sheet_y))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        # Let pygame process window events so the window stays responsive.
        pygame.event.pump()
        self.screen.fill((0, 0, 0))
        for row in range(self.TILE_Y_AMOUNT):
            for col in range(self.TILE_X_AMOUNT):
                self.drawTile(col, row, self.chart[row][col])
        pygame.display.update()
        self.clock.tick(self.metadata["render_fps"])

    def render(self):
        """Render the board according to ``render_mode``.

        ``"human"`` draws a pygame window; ``"console"`` prints the grid with
        ``?`` for covered tiles. Other modes do nothing.
        """
        if self.render_mode == "human":
            return self.render_frame_human()
        if self.render_mode == "console":
            for row in self.chart:
                line = "".join("?" if value < 0 else str(int(value)) for value in row)
                print(line)
        return None

    def close(self):
        """Shut down pygame and forget the window so it can be recreated later."""
        # Must not terminate the interpreter: callers routinely close an
        # environment and keep running.
        pygame.quit()
        self.screen = None
        self.clock = None

    def update_invalid_actions(self):
        """Recompute ``invalid_actions``: the flat indices of uncovered tiles."""
        # Row-major flat index of chart[row][col] is row * TILE_X_AMOUNT + col.
        self.invalid_actions = np.flatnonzero(self.chart >= 0).tolist()

    def action_masks(self) -> List[bool]:
        """Return, for every action, ``True`` if it is not in ``invalid_actions``."""
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]

    def revealChart(self):
        """Copy the full solution into the visible chart."""
        self.chart[...] = self.hiddenChart

    def guessed(self, x, y):
        """Return ``True`` if every in-bounds neighbour of ``(x, y)`` is still
        covered, i.e. the agent picked without adjacent information."""
        return all(self.chart[ny][nx] == -1 for nx, ny in self._neighbors(x, y))
#https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/docs/modules/ppo_mask.rst
#https://www.gymlibrary.dev/content/environment_creation/