Ivan000 committed · Commit c2ea5f9 · verified · 1 Parent(s): 995402e

Update app.py

Files changed (1): app.py (+16, -23)
app.py CHANGED
@@ -65,11 +65,13 @@ class Brick:
         self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
 
 class ArkanoidEnv(gym.Env):
-    def __init__(self):
+    def __init__(self, reward_size=1, penalty_size=-1):
         super(ArkanoidEnv, self).__init__()
         self.action_space = gym.spaces.Discrete(3)  # 0: stay, 1: move left, 2: move right
         self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
         self.seed_value = None
+        self.reward_size = reward_size
+        self.penalty_size = penalty_size
         self.reset()
 
     def reset(self, seed=None, options=None):
@@ -103,16 +105,16 @@ class ArkanoidEnv(gym.Env):
                 self.bricks.remove(brick)
                 self.ball.velocity[1] = -self.ball.velocity[1]
                 self.score += 1
-                reward = 1
+                reward = self.reward_size
                 if not self.bricks:
-                    reward += 10  # Bonus reward for breaking all bricks
+                    reward += self.reward_size * 10  # Bonus reward for breaking all bricks
                     self.done = True
                 truncated = False
                 return self._get_state(), reward, self.done, truncated, {}
 
         if self.ball.rect.bottom >= SCREEN_HEIGHT:
             self.done = True
-            reward = -1
+            reward = self.penalty_size
             truncated = False
         else:
             reward = 0
@@ -145,27 +147,14 @@ class ArkanoidEnv(gym.Env):
     def close(self):
         pygame.quit()
 
-# Training function
-def train_model(env, total_timesteps=10000):
-    model = DQN('MlpPolicy', env, verbose=1)
-    model.learn(total_timesteps=total_timesteps)
-    model.save("arkanoid_model")
-    return model
-
-# Evaluation function
-def evaluate_model(model, env):
-    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
-    return mean_reward
-
-# Real-time training function
-def train_and_play():
-    env = ArkanoidEnv()
+# Training and playing with custom parameters
+def train_and_play(reward_size, penalty_size, iterations):
+    env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size)
     model = DQN('MlpPolicy', env, verbose=1)
-    total_timesteps = 10000
     timesteps_per_update = 1000
     video_frames = []
 
-    for i in range(0, total_timesteps, timesteps_per_update):
+    for i in range(0, iterations, timesteps_per_update):
         model.learn(total_timesteps=timesteps_per_update)
         obs, _ = env.reset()
         done = False
@@ -192,10 +181,14 @@ def train_and_play():
 
 # Main function
 def main():
-    # Gradio interface
+    # Gradio interface with parameters
     iface = gr.Interface(
         fn=train_and_play,
-        inputs=None,
+        inputs=[
+            gr.Number(label="Reward Size", value=1),
+            gr.Number(label="Penalty Size", value=-1),
+            gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
+        ],
         outputs="video",
         live=True
     )
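
As a quick sanity check on the parametrized rewards, a minimal rollout sketch follows; it is not part of the commit and assumes ArkanoidEnv is importable from app.py as defined above.

# Sketch: random rollout to confirm rewards scale with the new parameters.
from app import ArkanoidEnv

env = ArkanoidEnv(reward_size=2, penalty_size=-5)
obs, info = env.reset(seed=0)
done, total_reward = False, 0.0
while not done:
    obs, reward, done, truncated, info = env.step(env.action_space.sample())
    total_reward += reward  # each brick hit adds +2; losing the ball adds -5
env.close()
print(f"episode return: {total_reward}")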
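
Two reviewer notes on the new train_and_play loop, sketched below; neither guard is in this commit. First, Gradio's Number and Slider widgets typically deliver floats, so range(0, iterations, ...) can raise a TypeError unless iterations is cast to int. Second, successive model.learn() calls in stable-baselines3 restart the timestep counter and exploration schedule by default; passing reset_num_timesteps=False makes the chunks train as one continuous run. The helper name train_chunked is hypothetical.

# Sketch of both guards (assumptions flagged above, not part of the commit).
from stable_baselines3 import DQN

def train_chunked(env, iterations, timesteps_per_update=1000):
    iterations = int(iterations)  # Gradio widgets deliver floats; range() needs ints
    model = DQN('MlpPolicy', env, verbose=1)
    for _ in range(0, iterations, timesteps_per_update):
        # reset_num_timesteps=False continues the timestep counter and the
        # exploration schedule across chunks instead of restarting them.
        model.learn(total_timesteps=timesteps_per_update, reset_num_timesteps=False)
    return model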
 