Update app.py
app.py CHANGED
@@ -106,14 +106,18 @@ class ArkanoidEnv(gym.Env):
                 self.bricks.remove(brick)
                 self.ball.velocity[1] = -self.ball.velocity[1]
                 self.score += 1
+                reward = 1
+                if not self.bricks:
+                    reward += 10  # Bonus reward for breaking all bricks
+                    self.done = True
+                return self._get_state(), reward, self.done, {}
 
         if self.ball.rect.bottom >= SCREEN_HEIGHT:
             self.done = True
+            reward = -1
+        else:
+            reward = 0
 
-        if not self.bricks:
-            self.done = True
-
-        reward = 1 if self.score > 0 else -1
         return self._get_state(), reward, self.done, {}
 
     def _get_state(self):
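The rewritten step() is event-based: +1 per brick, a +10 bonus for clearing the board, -1 for losing the ball, and 0 otherwise, replacing the old score-based reward = 1 if self.score > 0 else -1. Note that the brick-collision branch returns immediately, so the bottom-of-screen check never runs on a frame where a brick was hit. A self-contained sketch of the scheme; the boolean flags are hypothetical stand-ins for state that app.py derives from pygame collisions:

# Sketch of the per-step reward scheme introduced above; the flags are
# hypothetical inputs standing in for pygame collision state.
def step_reward(brick_hit: bool, board_cleared: bool, ball_lost: bool) -> int:
    if brick_hit:
        return 1 + (10 if board_cleared else 0)  # +10 bonus on the last brick
    if ball_lost:
        return -1  # losing the ball ends the episode
    return 0  # neutral frame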
@@ -142,16 +146,14 @@ class ArkanoidEnv(gym.Env):
         pygame.quit()
 
 # Training function
-def train_model():
-    env = ArkanoidEnv()
+def train_model(env, total_timesteps=10000):
     model = DQN('MlpPolicy', env, verbose=1)
-    model.learn(total_timesteps=10000)
+    model.learn(total_timesteps=total_timesteps)
     model.save("arkanoid_model")
     return model
 
 # Evaluation function
-def evaluate_model(model):
-    env = ArkanoidEnv()
+def evaluate_model(model, env):
     mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
     return mean_reward
 
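train_model and evaluate_model now take the environment as a parameter instead of each constructing its own ArkanoidEnv, so a single instance can be shared across training and evaluation. A usage sketch, assuming these names are importable from app.py:

# Hypothetical driver script; assumes app.py exposes these names.
from app import ArkanoidEnv, train_model, evaluate_model

env = ArkanoidEnv()
model = train_model(env, total_timesteps=10000)  # trains and saves "arkanoid_model"
mean_reward = evaluate_model(model, env)         # mean return over 10 episodes
print(f"Mean reward: {mean_reward:.2f}")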
@@ -170,22 +172,33 @@ def play_game():
         frames.append(gr.Image(value="frame.png"))
     return frames
 
-#
-def
-
-
-
-
-
+# Real-time training function
+def train_and_play():
+    env = ArkanoidEnv()
+    model = DQN('MlpPolicy', env, verbose=1)
+    total_timesteps = 10000
+    timesteps_per_update = 1000
+    frames = []
 
-
-
-
-
+    for i in range(0, total_timesteps, timesteps_per_update):
+        model.learn(total_timesteps=timesteps_per_update)
+        obs = env.reset()[0]
+        done = False
+        episode_frames = []
+        while not done:
+            action, _states = model.predict(obs, deterministic=True)
+            obs, reward, done, info = env.step(action)
+            env.render()
+            pygame.image.save(screen, "frame.png")
+            episode_frames.append(gr.Image(value="frame.png"))
+        frames.extend(episode_frames)
+        yield frames
 
+# Main function
+def main():
 # Gradio interface
 iface = gr.Interface(
-    fn=
+    fn=train_and_play,
     inputs=None,
     outputs="image",
     live=True
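Two details in the new code deserve a second look. First, train_and_play reads obs = env.reset()[0], which follows Gymnasium's (obs, info) reset convention, yet it unpacks the legacy four-tuple from env.step(); the step() shown in the first hunk returns the legacy form, so the [0] indexing would slice the observation itself. Second, main() is added directly above the unchanged module-level interface lines, so as committed it gains no indented body. A sketch of the rollout loop under the legacy Gym API that this env's step() actually returns; env and model are assumed to be the objects built inside train_and_play:

# Rollout under the legacy Gym API (step() -> obs, reward, done, info),
# matching ArkanoidEnv.step above; env and model assumed in scope.
obs = env.reset()  # legacy reset() returns the observation alone, no [0]
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()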