TheoVincent committed
Commit 83e5230 · 1 Parent(s): 3a3ae89

main files
Files changed (7):
  1. .gitattributes +5 -1
  2. README.md +6 -6
  3. atari.py +127 -0
  4. evaluate.ipynb +77 -0
  5. networks.py +38 -0
  6. offline_config.json +37 -0
  7. online_config.json +37 -0
.gitattributes CHANGED
@@ -23,7 +23,6 @@
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
  *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,8 @@
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ models/**/* filter=lfs diff=lfs merge=lfs -text
+ evaluations/**/* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -26,23 +26,23 @@ The sparse model parameters were obtained with [EauDeQN](https://arxiv.org/pdf/2

  5 seeds are available for each configuration, which makes a total of **750 available models** 📈.

- The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example to evaluate the model parameters 🧑‍🏫 It uses JAX 🚀 The hyperparameters used during training are reported in [config.json](./config.json) 🔧
+ The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example to evaluate the model parameters 🧑‍🏫 It uses JAX 🚀 The hyperparameters used during training are reported in [online_config.json](./online_config.json) and [offline_config.json](./offline_config.json) 🔧

  The training code will be available soon ⏳

- ### Model performances
- | <div style="width:300px; font-size: 30px; font-family:Serif; font-name:Times New Roman" > **EauDeDQN** and **EauDeCQL** achieve high sparsity while keeping performances high. <br> Published at [RLDM](https://arxiv.org/pdf/2503.01437)✨ </br> <div style="font-size: 16px"> <details> <summary id=games>List of Atari games</summary> *BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, VideoPinball.* </details> </div> </div> | <img src="performances.png" alt="drawing" width="600px"/> |
+ ### Model sparsity & performances
+ | <div style="width:300px; font-size: 30px; font-family:Serif; font-name:Times New Roman" > **EauDeDQN** and **EauDeCQL** achieve high sparsity while keeping performances high. <br> Published at [RLDM](https://arxiv.org/pdf/2503.01437)✨ </br> <div style="font-size: 16px"> <details> <summary id=games>List of Atari games</summary> *BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, VideoPinball.* </details> </div> </div> | <img src="sparsities.png" alt="drawing" width="600px"/> |
  | :-: | :-: |

+ The episodic returns and lengths are available in the [evaluations](./evaluations/) folder 🔬
+
  ## User installation
  Python 3.10 is recommended. Create a Python virtual environment, activate it, update pip, and install the package and its dependencies in editable mode:
  ```bash
  python3.10 -m venv env
  source env/bin/activate
- pip install --upgrade pip
- pip install numpy==1.23.5 # to avoid numpy==2.XX
+ pip install --upgrade pip setuptools wheel
  pip install -r requirements.txt
- pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  ```

  ## Citing `Eau De Q-Network`
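Since the README points readers to the two JSON files for the exact training hyperparameters, here is a minimal sketch of how one might inspect them. It only assumes the files sit at the repository root (as added in this commit) and that the "Feature Size" entry is a per-model placeholder (32, 512, or 2048, following the notebook).

```python
# Minimal sketch: print the shared training hyperparameters shipped with the models.
# Assumes online_config.json and offline_config.json are at the repository root.
import json

for name in ("online_config.json", "offline_config.json"):
    with open(name) as f:
        config = json.load(f)
    shared = config["shared_parameters"]
    print(name)
    print("  gamma:", shared["gamma"], "| learning rate:", shared["learning_rate"])
    # "Feature Size" is a placeholder for the width of the penultimate layer.
    print("  features:", shared["features"])
```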
atari.py ADDED
@@ -0,0 +1,127 @@
"""
This environment is inspired by https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
"""

import ale_py
from typing import Tuple, Dict
from gymnasium.wrappers import RecordVideo
import gymnasium as gym
import numpy as np
import jax
import jax.numpy as jnp
import cv2


class AtariEnv:
    def __init__(self, name: str) -> None:
        self.name = name
        self.state_height, self.state_width = (84, 84)
        self.n_stacked_frames = 4
        self.n_skipped_frames = 4

        gym.register_envs(ale_py)  # Register the ALE environments with gymnasium, which speeds up step().
        self.env = gym.make(
            f"ALE/{self.name}-v5",
            full_action_space=False,
            frameskip=1,
            repeat_action_probability=0.25,
            max_num_frames_per_episode=100_000,
            continuous=False,
            continuous_action_threshold=0.0,
            render_mode="rgb_array",
        ).env

        self.n_actions = self.env.action_space.n
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
        self.screen_buffer = [
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
        ]

    @property
    def observation(self) -> np.ndarray:
        return np.copy(self.state_[:, :, -1])

    @property
    def state(self) -> np.ndarray:
        return jnp.array(self.state_, dtype=jnp.float32)

    def reset(self) -> None:
        self.env.reset()

        self.n_steps = 0

        self.env.env.ale.getScreenGrayscale(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
        self.state_[:, :, -1] = self.resize()

    def step(self, action: jnp.int8) -> Tuple[float, bool]:
        reward = 0

        for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, _, _ = self.env.step(action)

            reward += reward_

            if idx_frame >= self.n_skipped_frames - 2:
                t = idx_frame - (self.n_skipped_frames - 2)
                self.env.env.ale.getScreenGrayscale(self.screen_buffer[t])

            if terminal:
                break

        self.state_ = np.roll(self.state_, -1, axis=-1)
        self.state_[:, :, -1] = self.pool_and_resize()

        self.n_steps += 1

        return reward, terminal

    def pool_and_resize(self) -> np.ndarray:
        # Max-pool over the last two raw frames to remove flickering, then resize.
        np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])

        return self.resize()

    def resize(self) -> np.ndarray:
        return np.asarray(
            cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
            dtype=np.uint8,
        )

    def evaluate_one_simulation(
        self,
        q,
        q_params: Dict,
        horizon: int,
        eps_eval: float,
        exploration_key: jax.random.PRNGKey,
        video_path: str,
    ) -> Tuple[float, bool]:
        ale = self.env.env.ale
        self.env = RecordVideo(
            self.env,
            video_folder=video_path if video_path is not None else ".",
            name_prefix="",
            episode_trigger=lambda x: video_path is not None,
        )
        self.env.env.ale = ale

        sum_reward = 0
        terminal = False
        self.reset()

        while not terminal and self.n_steps < horizon:
            exploration_key, key = jax.random.split(exploration_key)
            if jax.random.uniform(key) < eps_eval:
                action = jax.random.choice(key, jnp.arange(self.n_actions)).astype(jnp.int8)
            else:
                action = q.best_action(q_params, self.state)

            reward, terminal = self.step(action)

            sum_reward += reward

        self.env.close()
        return sum_reward, terminal
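For readers who want to exercise the environment wrapper on its own, below is a small usage sketch with random actions instead of a trained Q-network; it assumes the dependencies from requirements.txt (gymnasium, ale-py, opencv-python) and the Atari ROMs are installed.

```python
# Hedged usage sketch of AtariEnv: random policy for a few steps.
import numpy as np
from atari import AtariEnv

env = AtariEnv("Pong")  # any game listed in the README works
env.reset()

total_reward, terminal = 0.0, False
while not terminal and env.n_steps < 100:
    action = np.random.randint(env.n_actions)  # a trained agent would call q.best_action instead
    reward, terminal = env.step(action)
    total_reward += reward

print("state shape:", env.state.shape)  # (84, 84, 4): four stacked 84x84 grayscale frames
print("return after", env.n_steps, "steps:", total_reward)
```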
evaluate.ipynb ADDED
@@ -0,0 +1,77 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import pickle\n",
    "from atari import AtariEnv\n",
    "from networks import QNetwork\n",
    "\n",
    "# ------- START TO MODIFY ------- #\n",
    "ALGO = \"eaudedqn\" # choose between eaudedqn, polyprunedqn, dqn, eaudecql, polyprunecql, and cql.\n",
    "GAME = \"SpaceInvaders\" # choose between BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, and VideoPinball.\n",
    "FEATURE_SIZE = 32 # choose between 32, 512, and 2048.\n",
    "NETWORK_SEED = 1 # choose between 1, 2, 3, 4, and 5.\n",
    "EVALUATION_SEED = 0\n",
    "HORIZON = 27000\n",
    "EPSILON = 0.01\n",
    "RECORD_VIDEO = False\n",
    "# ------- END TO MODIFY ------- #\n",
    "\n",
    "params_path = f\"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{NETWORK_SEED}\"\n",
    "\n",
    "env = AtariEnv(GAME)\n",
    "\n",
    "q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)\n",
    "\n",
    "with open(params_path, \"rb\") as handle:\n",
    "    q_params = pickle.load(handle)\n",
    "\n",
    "return_, absorbing = env.evaluate_one_simulation(\n",
    "    q, q_params, HORIZON, EPSILON, jax.random.PRNGKey(EVALUATION_SEED), params_path + \"_eval\" if RECORD_VIDEO else None\n",
    ")\n",
    "print(\"Undiscounted return:\", return_)\n",
    "print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)\n",
    "non_zeros = sum(jax.tree.leaves(jax.tree.map(jnp.count_nonzero, q_params)))\n",
    "n_params = sum(jax.tree.leaves(jax.tree.map(jnp.size, q_params)))\n",
    "print(\"Sparsity level:\", (1 - jnp.float32(non_zeros) / jnp.float32(n_params)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
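The notebook evaluates a single seed; the sketch below extends it to average the undiscounted return over the five available seeds of one configuration. It reuses the notebook's names and paths, so the same assumptions apply (model files downloaded locally, dependencies installed).

```python
# Hedged sketch: average the evaluation return over the 5 seeds of one configuration.
import pickle
import jax
from atari import AtariEnv
from networks import QNetwork

ALGO, GAME, FEATURE_SIZE = "eaudedqn", "SpaceInvaders", 32
HORIZON, EPSILON = 27_000, 0.01

returns = []
for seed in range(1, 6):
    env = AtariEnv(GAME)  # fresh environment per run to avoid stacking video wrappers
    q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)
    with open(f"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{seed}", "rb") as handle:
        q_params = pickle.load(handle)
    return_, _ = env.evaluate_one_simulation(q, q_params, HORIZON, EPSILON, jax.random.PRNGKey(0), None)
    returns.append(return_)

print("mean undiscounted return over seeds:", sum(returns) / len(returns))
```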
networks.py ADDED
@@ -0,0 +1,38 @@
from typing import Sequence
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn


class DQNNet(nn.Module):
    features: Sequence[int]
    n_actions: int

    @nn.compact
    def __call__(self, x):
        initializer = nn.initializers.xavier_uniform()
        x = nn.relu(
            nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(
                jnp.array(x, ndmin=4) / 255.0
            )
        )
        x = nn.relu(nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x))
        x = x.reshape((x.shape[0], -1))

        x = jnp.squeeze(x)

        for idx_layer in range(3, len(self.features)):
            x = nn.relu(nn.Dense(self.features[idx_layer], kernel_init=initializer)(x))

        return nn.Dense(self.n_actions, kernel_init=initializer)(x)


class QNetwork:
    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        self.network = DQNNet(features, n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        return jnp.argmax(self.network.apply(params, state)).astype(jnp.int8)
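To sanity-check the architecture against a downloaded parameter file, one can create a freshly initialized parameter tree with flax and compare shapes. The sketch below assumes the saved pickles follow the standard flax parameter layout returned by init, which the notebook's direct use of q_params with best_action suggests.

```python
# Hedged sketch: initialize DQNNet parameters and query a greedy action on a dummy state.
import jax
import jax.numpy as jnp
from networks import QNetwork

FEATURE_SIZE = 32
N_ACTIONS = 6  # e.g. SpaceInvaders with the minimal action set; read env.n_actions in practice

q = QNetwork([32, 64, 64, FEATURE_SIZE], N_ACTIONS)
dummy_state = jnp.zeros((84, 84, 4), dtype=jnp.float32)
params = q.network.init(jax.random.PRNGKey(0), dummy_state)

print(jax.tree.map(jnp.shape, params))  # layer-by-layer parameter shapes
print("greedy action:", q.best_action(params, dummy_state))
```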
offline_config.json ADDED
@@ -0,0 +1,37 @@
{
    "shared_parameters": {
        "features": [
            32,
            64,
            64,
            "Feature Size"
        ],
        "replay_buffer_capacity": 50000,
        "batch_size": 32,
        "update_horizon": 1,
        "gamma": 0.99,
        "learning_rate": 5e-05,
        "architecture_type": "cnn",
        "target_update_frequency": 2000,
        "n_buffers_to_load": 5,
        "n_epochs": 50,
        "n_fitting_steps": 62500
    },
    "eaudecql": {
        "n_networks": 5,
        "max_noise": 3.0,
        "max_speed": 0.01,
        "reset_optimizer": true,
        "alpha_cql": 0.1
    },
    "polyprunecql": {
        "sparcity_start_step": 625000,
        "sparcity_end_step": 2500000,
        "sparcity_update_freq": 1000,
        "final_sparsity": 0.95,
        "alpha_cql": 0.1
    },
    "cql": {
        "alpha_cql": 0.1
    }
}
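A quick back-of-the-envelope check on the offline schedule, using only the numbers from the JSON above; treating n_fitting_steps as gradient updates per epoch is an assumption made here (it is the only reading consistent with the pruning steps below).

```python
# Offline training budget implied by offline_config.json (values copied from above).
n_epochs = 50
n_fitting_steps_per_epoch = 62_500  # assumed to be a per-epoch count
total_steps = n_epochs * n_fitting_steps_per_epoch
print("total fitting steps:", total_steps)  # 3,125,000

# PolyPruneCQL prunes between sparcity_start_step and sparcity_end_step.
start, end = 625_000, 2_500_000
print("pruning window:", start / total_steps, "to", end / total_steps, "of training")  # 0.2 to 0.8
```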
online_config.json ADDED
@@ -0,0 +1,37 @@
{
    "shared_parameters": {
        "replay_buffer_capacity": 1000000,
        "batch_size": 32,
        "update_horizon": 1,
        "gamma": 0.99,
        "learning_rate": 6.25e-05,
        "horizon": 27000,
        "n_epochs": 40,
        "n_training_steps_per_epoch": 250000,
        "n_initial_samples": 20000,
        "epsilon_end": 0.01,
        "epsilon_duration": 250000.0,
        "target_update_frequency": 8000,
        "update_to_data": 4.0,
        "features": [
            32,
            64,
            64,
            "Feature Size"
        ],
        "architecture_type": "cnn"
    },
    "eaudedqn": {
        "n_networks": 5,
        "max_noise": 3.0,
        "max_speed": 0.01,
        "reset_optimizer": true
    },
    "polyprunedqn": {
        "sparcity_start_step": 2000000,
        "sparcity_end_step": 8000000,
        "sparcity_update_freq": 4000,
        "final_sparsity": 0.95
    },
    "dqn": {}
}
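The same sanity check for the online schedule, again using only numbers from the JSON above; the epsilon-greedy schedule is assumed to decay linearly from 1.0 to epsilon_end over epsilon_duration steps, which is the standard DQN setup but is not stated in the config itself.

```python
# Online training budget implied by online_config.json (values copied from above).
n_epochs = 40
n_training_steps_per_epoch = 250_000
total_steps = n_epochs * n_training_steps_per_epoch
print("total training steps:", total_steps)  # 10,000,000

# PolyPruneDQN pruning window as a fraction of the budget.
start, end = 2_000_000, 8_000_000
print("pruning window:", start / total_steps, "to", end / total_steps, "of training")  # 0.2 to 0.8

# Assumed linear epsilon decay from 1.0 to epsilon_end over epsilon_duration steps.
epsilon_end, epsilon_duration = 0.01, 250_000.0
step = 100_000
epsilon = max(epsilon_end, 1.0 - (1.0 - epsilon_end) * step / epsilon_duration)
print("epsilon at step", step, ":", round(epsilon, 3))  # 0.604
```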