File size: 6,895 Bytes
3c22597
375aee6
3c22597
 
 
ca7808d
3c22597
68621da
3c22597
a5cda12
3c22597
 
 
 
 
 
 
 
 
 
 
40ec26a
3c22597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995402e
3c22597
ca7808d
927c930
ca7808d
 
927c930
c2ea5f9
 
8279e35
ca7808d
3c22597
d3b8fe9
927c930
 
 
965d906
3c22597
68621da
995402e
3c22597
 
d3b8fe9
3c22597
 
ca7808d
 
 
 
 
 
 
3c22597
 
 
 
68621da
3c22597
 
 
 
 
 
927c930
0c990cc
927c930
0c990cc
927c930
 
3c22597
 
 
927c930
 
 
 
 
3c22597
927c930
3c22597
 
927c930
 
 
 
 
 
 
 
 
 
 
3c22597
68621da
 
 
 
 
3c22597
68621da
 
 
 
 
 
 
ca7808d
 
 
3c22597
927c930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a14aa44
927c930
a14aa44
 
927c930
0c990cc
3c22597
0c990cc
c2ea5f9
 
 
68621da
c2ea5f9
 
927c930
 
3c22597
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
os.environ["XDG_RUNTIME_DIR"] = "/tmp"
import numpy as np
import pygame
import random
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
import gradio as gr
import cv2

# Playfield size, in pixels.
SCREEN_WIDTH, SCREEN_HEIGHT = 640, 480

# Paddle size, in pixels.
PADDLE_WIDTH, PADDLE_HEIGHT = 100, 10

# The ball's bounding box is a square of side 2 * BALL_RADIUS.
BALL_RADIUS = 10

# Brick grid cell size (pixels); BRICK_ROWS * BRICK_COLS sizes the observation vector.
BRICK_WIDTH, BRICK_HEIGHT = 60, 20
BRICK_ROWS, BRICK_COLS = 5, 10

# Frame rate of the recorded video.
FPS = 40

# RGB colors used for drawing.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)

# Initialize Pygame once, at import time. Note that ArkanoidEnv.close() calls
# pygame.quit(), which shuts Pygame down again for the whole process.
pygame.init()

# Game classes
class Paddle:
    """Horizontal paddle resting near the bottom edge of the playfield."""

    def __init__(self):
        # Horizontally centred, with a 10 px gap below it.
        left = SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2
        top = SCREEN_HEIGHT - PADDLE_HEIGHT - 10
        self.rect = pygame.Rect(left, top, PADDLE_WIDTH, PADDLE_HEIGHT)

    def move(self, direction):
        """Slide 10 px left (direction == -1) or right (direction == 1).

        Any other direction leaves the x position as it is. The paddle is then
        clamped so that it stays entirely inside the screen.
        """
        shift = -10 if direction == -1 else (10 if direction == 1 else 0)
        self.rect.x += shift
        screen_bounds = pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT)
        self.rect.clamp_ip(screen_bounds)

class Ball:
    """The ball: a square bounding box of side 2 * BALL_RADIUS and an [vx, vy] velocity."""

    def __init__(self):
        # Single source of truth for the starting position/velocity.
        self.reset()

    def move(self):
        """Advance one step and bounce off the left, right and top walls.

        The bottom edge is not handled here; the environment treats reaching it
        as losing the ball.
        """
        self.rect.x += self.velocity[0]
        self.rect.y += self.velocity[1]

        # Point the velocity away from the touched wall instead of blindly
        # negating it: a ball that is already heading back into the field must
        # not be flipped back into the wall, or it would oscillate there.
        if self.rect.left <= 0:
            self.velocity[0] = abs(self.velocity[0])
        elif self.rect.right >= SCREEN_WIDTH:
            self.velocity[0] = -abs(self.velocity[0])
        if self.rect.top <= 0:
            self.velocity[1] = abs(self.velocity[1])

    def reset(self):
        """Centre the ball, moving upward with a random horizontal direction."""
        self.rect = pygame.Rect(
            SCREEN_WIDTH // 2 - BALL_RADIUS,
            SCREEN_HEIGHT // 2 - BALL_RADIUS,
            BALL_RADIUS * 2,
            BALL_RADIUS * 2,
        )
        self.velocity = [random.choice([-5, 5]), -5]

class Brick:
    """One destructible brick whose top-left corner sits at (x, y)."""

    def __init__(self, x, y):
        # Drawn 5 px smaller than its grid cell on each axis, leaving a gap
        # between neighbouring bricks.
        width = BRICK_WIDTH - 5
        height = BRICK_HEIGHT - 5
        self.rect = pygame.Rect(x, y, width, height)

class ArkanoidEnv(gym.Env):
    """Gymnasium environment for a minimal Arkanoid/Breakout game.

    Actions (``Discrete(3)``): 0 = stay, 1 = move paddle left, 2 = move paddle right.

    Observation: float32 vector of length ``5 + 2 * BRICK_ROWS * BRICK_COLS`` holding
    paddle x, ball x, ball y, ball vx, ball vy, then the (x, y) of every remaining
    brick, zero-padded for bricks that are gone.

    Per-step reward is the sum of:
      * ``platform_reward`` when the ball bounces off the paddle;
      * ``reward_size`` per brick broken, plus a ``10 * reward_size`` bonus when
        the last brick is broken (the episode then terminates);
      * ``penalty_size`` when the ball reaches the bottom edge (episode terminates).

    Episodes are truncated after ``max_steps`` steps (``None`` disables the cap)
    so a ball caught in an endless bounce loop cannot stall training or rollouts.
    """

    def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5, max_steps=10000):
        super().__init__()
        self.action_space = gym.spaces.Discrete(3)  # 0: stay, 1: move left, 2: move right
        # Ball velocity components are negative half of the time, so the lower
        # bound must be below zero for every observation to lie inside the space.
        self.observation_space = gym.spaces.Box(
            low=-SCREEN_WIDTH,
            high=SCREEN_WIDTH,
            shape=(5 + BRICK_ROWS * BRICK_COLS * 2,),
            dtype=np.float32,
        )
        self.reward_size = reward_size
        self.penalty_size = penalty_size
        self.platform_reward = platform_reward
        self.max_steps = max_steps
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``."""
        super().reset(seed=seed)
        if seed is not None:
            # Ball() draws its initial direction from the global ``random`` module.
            random.seed(seed)
            np.random.seed(seed)
        self.paddle = Paddle()
        self.ball = Ball()
        self.bricks = [
            Brick(x, y)
            for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
            for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)
        ]
        self.done = False
        self.score = 0
        self.step_count = 0
        return self._get_state(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 = stay, 1 = move left, 2 = move right.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, where ``reward``
            sums every event of this frame (paddle bounce, bricks, bonus, lost ball).
        """
        if action == 0:
            self.paddle.move(0)
        elif action == 1:
            self.paddle.move(-1)
        elif action == 2:
            self.paddle.move(1)

        self.ball.move()
        self.step_count += 1
        reward = 0

        # Only bounce a falling ball: a ball that overlaps the paddle for several
        # frames (e.g. a side hit) would otherwise flip direction every frame and
        # collect the platform reward repeatedly.
        if self.ball.velocity[1] > 0 and self.ball.rect.colliderect(self.paddle.rect):
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += self.platform_reward
            reward += self.platform_reward

        remaining = [brick for brick in self.bricks if not self.ball.rect.colliderect(brick.rect)]
        broken = len(self.bricks) - len(remaining)
        if broken:
            self.bricks = remaining
            # Reverse once per frame; reversing once per brick would cancel out
            # when two bricks are hit together and let the ball tunnel through.
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += broken
            reward += self.reward_size * broken
            if not self.bricks:
                reward += self.reward_size * 10  # Bonus reward for breaking all bricks
                self.done = True
                return self._get_state(), reward, self.done, False, {}

        if self.ball.rect.bottom >= SCREEN_HEIGHT:
            self.done = True
            reward += self.penalty_size

        truncated = (
            not self.done
            and self.max_steps is not None
            and self.step_count >= self.max_steps
        )
        return self._get_state(), reward, self.done, truncated, {}

    def _get_state(self):
        """Build the flat float32 observation vector described in the class docstring."""
        state = [
            self.paddle.rect.x,
            self.ball.rect.x,
            self.ball.rect.y,
            self.ball.velocity[0],
            self.ball.velocity[1],
        ]
        for brick in self.bricks:
            state.extend([brick.rect.x, brick.rect.y])
        # Pad with (0, 0) pairs so the vector length never changes as bricks disappear.
        state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks)))
        return np.array(state, dtype=np.float32)

    def render(self, mode='rgb_array'):
        """Draw the current frame.

        ``'rgb_array'`` returns ``pygame.surfarray.array3d`` output, which is indexed
        ``[x, y, channel]`` (i.e. transposed relative to a ``[row, col, channel]``
        image). ``'human'`` blits onto the current pygame display surface, which
        must already have been created by the caller.
        """
        surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
        surface.fill(BLACK)
        pygame.draw.rect(surface, WHITE, self.paddle.rect)
        pygame.draw.ellipse(surface, WHITE, self.ball.rect)
        for brick in self.bricks:
            pygame.draw.rect(surface, RED, brick.rect)

        if mode == 'rgb_array':
            return pygame.surfarray.array3d(surface)
        elif mode == 'human':
            pygame.display.get_surface().blit(surface, (0, 0))
            pygame.display.flip()

    def close(self):
        """Shut Pygame down (process-wide, not just for this environment)."""
        pygame.quit()

# Training and playing with custom parameters
def _record_episode(model, env, video_writer, max_steps=10000):
    """Play one greedy episode of ``model`` on ``env``, writing each frame to ``video_writer``.

    The rollout stops on termination, truncation, or after ``max_steps`` frames,
    so a policy stuck in a bounce loop cannot hang the caller.
    """
    obs, _ = env.reset()
    for _ in range(max_steps):
        action, _states = model.predict(obs, deterministic=True)
        obs, _reward, terminated, truncated, _info = env.step(action)

        frame = env.render(mode='rgb_array')
        # ArkanoidEnv.render returns pygame.surfarray.array3d output, indexed
        # [x, y, channel]; swapping the first two axes yields a [row, col, channel]
        # image. (np.rot90 would additionally flip the picture upside down.)
        frame = np.ascontiguousarray(np.transpose(frame, (1, 0, 2)))
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        if terminated or truncated:
            break


def train_and_play(reward_size, penalty_size, platform_reward, iterations):
    """Train a DQN agent on ArkanoidEnv and record greedy rollouts to an MP4 file.

    Training runs in chunks of at most 1000 timesteps; after each chunk one
    deterministic episode is played on a separate evaluation environment and
    its frames are appended to the video.

    Args:
        reward_size: reward per broken brick (passed to ArkanoidEnv).
        penalty_size: reward applied when the ball is lost.
        platform_reward: reward for bouncing the ball off the paddle.
        iterations: total number of training timesteps. Converted to ``int``
            because UI widgets may deliver it as a float.

    Returns:
        The path of the written video, ``"arkanoid_training.mp4"``.

    Raises:
        RuntimeError: if OpenCV cannot open the video writer.
    """
    iterations = int(iterations)
    env_kwargs = dict(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
    env = ArkanoidEnv(**env_kwargs)
    # Roll out on a separate instance: stepping the training env directly would
    # desynchronise it from the observation DQN keeps for its next learn() call.
    eval_env = ArkanoidEnv(**env_kwargs)
    model = DQN('MlpPolicy', env, verbose=1)
    timesteps_per_update = min(1000, iterations)

    video_path = "arkanoid_training.mp4"
    video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
    try:
        if not video_writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {video_path}")

        completed_iterations = 0
        while completed_iterations < iterations:
            steps = min(timesteps_per_update, iterations - completed_iterations)
            # Continue the same run instead of restarting the timestep counter,
            # so the exploration schedule and learning_starts are not re-applied
            # at the beginning of every chunk.
            model.learn(total_timesteps=steps, reset_num_timesteps=False)
            completed_iterations += steps
            # Stream frames straight to disk rather than buffering the whole video.
            _record_episode(model, eval_env, video_writer)
    finally:
        video_writer.release()
        env.close()
        eval_env.close()

    return video_path

# Main function with Gradio interface
def main():
    """Wrap train_and_play in a Gradio form and start the web server."""
    parameter_inputs = [
        gr.Number(label="Reward Size", value=1),
        gr.Number(label="Penalty Size", value=-1),
        gr.Number(label="Platform Reward", value=5),
        gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000),
    ]
    demo = gr.Interface(
        fn=train_and_play,
        inputs=parameter_inputs,
        outputs="video",
        # Run only when the user submits, not on every input change.
        live=False,
    )
    demo.launch()

# Start the Gradio app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()