Ivan000 commited on
Commit
0c990cc
·
verified ·
1 Parent(s): d3b8fe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -21
app.py CHANGED
@@ -106,14 +106,18 @@ class ArkanoidEnv(gym.Env):
106
  self.bricks.remove(brick)
107
  self.ball.velocity[1] = -self.ball.velocity[1]
108
  self.score += 1
 
 
 
 
 
109
 
110
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
111
  self.done = True
 
 
 
112
 
113
- if not self.bricks:
114
- self.done = True
115
-
116
- reward = 1 if self.score > 0 else -1
117
  return self._get_state(), reward, self.done, {}
118
 
119
  def _get_state(self):
@@ -142,16 +146,14 @@ class ArkanoidEnv(gym.Env):
142
  pygame.quit()
143
 
144
  # Training function
145
- def train_model():
146
- env = ArkanoidEnv()
147
  model = DQN('MlpPolicy', env, verbose=1)
148
- model.learn(total_timesteps=10000)
149
  model.save("arkanoid_model")
150
  return model
151
 
152
  # Evaluation function
153
- def evaluate_model(model):
154
- env = ArkanoidEnv()
155
  mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
156
  return mean_reward
157
 
@@ -170,22 +172,33 @@ def play_game():
170
  frames.append(gr.Image(value="frame.png"))
171
  return frames
172
 
173
- # Main function
174
- def main():
175
- if not os.path.exists("arkanoid_model.zip"):
176
- print("Training model...")
177
- train_model()
178
- else:
179
- print("Model already trained.")
180
 
181
- print("Evaluating model...")
182
- model = DQN.load("arkanoid_model")
183
- mean_reward = evaluate_model(model)
184
- print(f"Mean reward: {mean_reward}")
 
 
 
 
 
 
 
 
 
185
 
 
 
186
  # Gradio interface
187
  iface = gr.Interface(
188
- fn=play_game,
189
  inputs=None,
190
  outputs="image",
191
  live=True
 
106
  self.bricks.remove(brick)
107
  self.ball.velocity[1] = -self.ball.velocity[1]
108
  self.score += 1
109
+ reward = 1
110
+ if not self.bricks:
111
+ reward += 10 # Bonus reward for breaking all bricks
112
+ self.done = True
113
+ return self._get_state(), reward, self.done, {}
114
 
115
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
116
  self.done = True
117
+ reward = -1
118
+ else:
119
+ reward = 0
120
 
 
 
 
 
121
  return self._get_state(), reward, self.done, {}
122
 
123
  def _get_state(self):
 
146
  pygame.quit()
147
 
148
  # Training function
149
+ def train_model(env, total_timesteps=10000):
 
150
  model = DQN('MlpPolicy', env, verbose=1)
151
+ model.learn(total_timesteps=total_timesteps)
152
  model.save("arkanoid_model")
153
  return model
154
 
155
  # Evaluation function
156
+ def evaluate_model(model, env):
 
157
  mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
158
  return mean_reward
159
 
 
172
  frames.append(gr.Image(value="frame.png"))
173
  return frames
174
 
175
+ # Real-time training function
176
+ def train_and_play():
177
+ env = ArkanoidEnv()
178
+ model = DQN('MlpPolicy', env, verbose=1)
179
+ total_timesteps = 10000
180
+ timesteps_per_update = 1000
181
+ frames = []
182
 
183
+ for i in range(0, total_timesteps, timesteps_per_update):
184
+ model.learn(total_timesteps=timesteps_per_update)
185
+ obs = env.reset()[0]
186
+ done = False
187
+ episode_frames = []
188
+ while not done:
189
+ action, _states = model.predict(obs, deterministic=True)
190
+ obs, reward, done, info = env.step(action)
191
+ env.render()
192
+ pygame.image.save(screen, "frame.png")
193
+ episode_frames.append(gr.Image(value="frame.png"))
194
+ frames.extend(episode_frames)
195
+ yield frames
196
 
197
+ # Main function
198
+ def main():
199
  # Gradio interface
200
  iface = gr.Interface(
201
+ fn=train_and_play,
202
  inputs=None,
203
  outputs="image",
204
  live=True