Update app.py
Browse files
app.py
CHANGED
@@ -7,9 +7,8 @@ import os
|
|
7 |
import numpy as np
|
8 |
import pygame
|
9 |
import random
|
|
|
10 |
from stable_baselines3 import DQN
|
11 |
-
from stable_baselines3.common.env_util import make_atari_env
|
12 |
-
from stable_baselines3.common.vec_env import VecFrameStack
|
13 |
from stable_baselines3.common.evaluation import evaluate_policy
|
14 |
import gradio as gr
|
15 |
|
@@ -69,14 +68,12 @@ class Brick:
|
|
69 |
def __init__(self, x, y):
|
70 |
self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
|
71 |
|
72 |
-
class ArkanoidEnv:
|
73 |
def __init__(self):
|
74 |
-
self.
|
75 |
-
self.
|
76 |
-
self.
|
77 |
-
self.
|
78 |
-
self.done = False
|
79 |
-
self.score = 0
|
80 |
|
81 |
def reset(self):
|
82 |
self.paddle = Paddle()
|
@@ -87,7 +84,13 @@ class ArkanoidEnv:
|
|
87 |
return self._get_state()
|
88 |
|
89 |
def step(self, action):
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
self.ball.move()
|
92 |
|
93 |
if self.ball.rect.colliderect(self.paddle.rect):
|
@@ -118,16 +121,20 @@ class ArkanoidEnv:
|
|
118 |
]
|
119 |
for brick in self.bricks:
|
120 |
state.extend([brick.rect.x, brick.rect.y])
|
|
|
121 |
return np.array(state, dtype=np.float32)
|
122 |
|
123 |
-
def render(self):
|
124 |
screen.fill(BLACK)
|
125 |
pygame.draw.rect(screen, WHITE, self.paddle.rect)
|
126 |
pygame.draw.ellipse(screen, WHITE, self.ball.rect)
|
127 |
for brick in self.bricks:
|
128 |
pygame.draw.rect(screen, RED, brick.rect)
|
129 |
pygame.display.flip()
|
130 |
-
|
|
|
|
|
|
|
131 |
|
132 |
# Training function
|
133 |
def train_model():
|
@@ -152,7 +159,7 @@ def play_game():
|
|
152 |
frames = []
|
153 |
while not done:
|
154 |
action, _states = model.predict(obs, deterministic=True)
|
155 |
-
obs,
|
156 |
env.render()
|
157 |
pygame.image.save(screen, "frame.png")
|
158 |
frames.append(gr.Image(value="frame.png"))
|
@@ -190,6 +197,7 @@ if __name__ == "__main__":
|
|
190 |
# - stable-baselines3
|
191 |
# - torch
|
192 |
# - gradio
|
|
|
193 |
#
|
194 |
# You can install these dependencies using pip:
|
195 |
-
# pip install pygame stable-baselines3 torch gradio
|
|
|
7 |
import numpy as np
|
8 |
import pygame
|
9 |
import random
|
10 |
+
import gymnasium as gym
|
11 |
from stable_baselines3 import DQN
|
|
|
|
|
12 |
from stable_baselines3.common.evaluation import evaluate_policy
|
13 |
import gradio as gr
|
14 |
|
|
|
68 |
def __init__(self, x, y):
|
69 |
self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
|
70 |
|
71 |
+
class ArkanoidEnv(gym.Env):
|
72 |
def __init__(self):
|
73 |
+
super(ArkanoidEnv, self).__init__()
|
74 |
+
self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
|
75 |
+
self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
|
76 |
+
self.reset()
|
|
|
|
|
77 |
|
78 |
def reset(self):
|
79 |
self.paddle = Paddle()
|
|
|
84 |
return self._get_state()
|
85 |
|
86 |
def step(self, action):
|
87 |
+
if action == 0:
|
88 |
+
self.paddle.move(0)
|
89 |
+
elif action == 1:
|
90 |
+
self.paddle.move(-1)
|
91 |
+
elif action == 2:
|
92 |
+
self.paddle.move(1)
|
93 |
+
|
94 |
self.ball.move()
|
95 |
|
96 |
if self.ball.rect.colliderect(self.paddle.rect):
|
|
|
121 |
]
|
122 |
for brick in self.bricks:
|
123 |
state.extend([brick.rect.x, brick.rect.y])
|
124 |
+
state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
|
125 |
return np.array(state, dtype=np.float32)
|
126 |
|
127 |
+
def render(self, mode='human'):
|
128 |
screen.fill(BLACK)
|
129 |
pygame.draw.rect(screen, WHITE, self.paddle.rect)
|
130 |
pygame.draw.ellipse(screen, WHITE, self.ball.rect)
|
131 |
for brick in self.bricks:
|
132 |
pygame.draw.rect(screen, RED, brick.rect)
|
133 |
pygame.display.flip()
|
134 |
+
pygame.time.Clock().tick(FPS)
|
135 |
+
|
136 |
+
def close(self):
|
137 |
+
pygame.quit()
|
138 |
|
139 |
# Training function
|
140 |
def train_model():
|
|
|
159 |
frames = []
|
160 |
while not done:
|
161 |
action, _states = model.predict(obs, deterministic=True)
|
162 |
+
obs, reward, done, info = env.step(action)
|
163 |
env.render()
|
164 |
pygame.image.save(screen, "frame.png")
|
165 |
frames.append(gr.Image(value="frame.png"))
|
|
|
197 |
# - stable-baselines3
|
198 |
# - torch
|
199 |
# - gradio
|
200 |
+
# - gymnasium
|
201 |
#
|
202 |
# You can install these dependencies using pip:
|
203 |
+
# pip install pygame stable-baselines3 torch gradio gymnasium
|