Ivan000 commited on
Commit
ca7808d
·
verified ·
1 Parent(s): 3c22597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -7,9 +7,8 @@ import os
7
  import numpy as np
8
  import pygame
9
  import random
 
10
  from stable_baselines3 import DQN
11
- from stable_baselines3.common.env_util import make_atari_env
12
- from stable_baselines3.common.vec_env import VecFrameStack
13
  from stable_baselines3.common.evaluation import evaluate_policy
14
  import gradio as gr
15
 
@@ -69,14 +68,12 @@ class Brick:
69
  def __init__(self, x, y):
70
  self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
71
 
72
- class ArkanoidEnv:
73
  def __init__(self):
74
- self.paddle = Paddle()
75
- self.ball = Ball()
76
- self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT) for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
77
- self.clock = pygame.time.Clock()
78
- self.done = False
79
- self.score = 0
80
 
81
  def reset(self):
82
  self.paddle = Paddle()
@@ -87,7 +84,13 @@ class ArkanoidEnv:
87
  return self._get_state()
88
 
89
  def step(self, action):
90
- self.paddle.move(action)
 
 
 
 
 
 
91
  self.ball.move()
92
 
93
  if self.ball.rect.colliderect(self.paddle.rect):
@@ -118,16 +121,20 @@ class ArkanoidEnv:
118
  ]
119
  for brick in self.bricks:
120
  state.extend([brick.rect.x, brick.rect.y])
 
121
  return np.array(state, dtype=np.float32)
122
 
123
- def render(self):
124
  screen.fill(BLACK)
125
  pygame.draw.rect(screen, WHITE, self.paddle.rect)
126
  pygame.draw.ellipse(screen, WHITE, self.ball.rect)
127
  for brick in self.bricks:
128
  pygame.draw.rect(screen, RED, brick.rect)
129
  pygame.display.flip()
130
- self.clock.tick(FPS)
 
 
 
131
 
132
  # Training function
133
  def train_model():
@@ -152,7 +159,7 @@ def play_game():
152
  frames = []
153
  while not done:
154
  action, _states = model.predict(obs, deterministic=True)
155
- obs, rewards, done, info = env.step(action)
156
  env.render()
157
  pygame.image.save(screen, "frame.png")
158
  frames.append(gr.Image(value="frame.png"))
@@ -190,6 +197,7 @@ if __name__ == "__main__":
190
  # - stable-baselines3
191
  # - torch
192
  # - gradio
 
193
  #
194
  # You can install these dependencies using pip:
195
- # pip install pygame stable-baselines3 torch gradio
 
7
  import numpy as np
8
  import pygame
9
  import random
10
+ import gymnasium as gym
11
  from stable_baselines3 import DQN
 
 
12
  from stable_baselines3.common.evaluation import evaluate_policy
13
  import gradio as gr
14
 
 
68
  def __init__(self, x, y):
69
  self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
70
 
71
+ class ArkanoidEnv(gym.Env):
72
  def __init__(self):
73
+ super(ArkanoidEnv, self).__init__()
74
+ self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
75
+ self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
76
+ self.reset()
 
 
77
 
78
  def reset(self):
79
  self.paddle = Paddle()
 
84
  return self._get_state()
85
 
86
  def step(self, action):
87
+ if action == 0:
88
+ self.paddle.move(0)
89
+ elif action == 1:
90
+ self.paddle.move(-1)
91
+ elif action == 2:
92
+ self.paddle.move(1)
93
+
94
  self.ball.move()
95
 
96
  if self.ball.rect.colliderect(self.paddle.rect):
 
121
  ]
122
  for brick in self.bricks:
123
  state.extend([brick.rect.x, brick.rect.y])
124
+ state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
125
  return np.array(state, dtype=np.float32)
126
 
127
+ def render(self, mode='human'):
128
  screen.fill(BLACK)
129
  pygame.draw.rect(screen, WHITE, self.paddle.rect)
130
  pygame.draw.ellipse(screen, WHITE, self.ball.rect)
131
  for brick in self.bricks:
132
  pygame.draw.rect(screen, RED, brick.rect)
133
  pygame.display.flip()
134
+ pygame.time.Clock().tick(FPS)
135
+
136
+ def close(self):
137
+ pygame.quit()
138
 
139
  # Training function
140
  def train_model():
 
159
  frames = []
160
  while not done:
161
  action, _states = model.predict(obs, deterministic=True)
162
+ obs, reward, done, info = env.step(action)
163
  env.render()
164
  pygame.image.save(screen, "frame.png")
165
  frames.append(gr.Image(value="frame.png"))
 
197
  # - stable-baselines3
198
  # - torch
199
  # - gradio
200
+ # - gymnasium
201
  #
202
  # You can install these dependencies using pip:
203
+ # pip install pygame stable-baselines3 torch gradio gymnasium