Commit 83e5230
1 Parent(s): 3a3ae89

main files

Files changed:
- .gitattributes (+5 −1)
- README.md (+6 −6)
- atari.py (+127 −0)
- evaluate.ipynb (+77 −0)
- networks.py (+38 −0)
- offline_config.json (+37 −0)
- online_config.json (+37 −0)
.gitattributes CHANGED

```diff
@@ -23,7 +23,6 @@
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,8 @@
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+models/**/* filter=lfs diff=lfs merge=lfs -text
+evaluations/**/* filter=lfs diff=lfs merge=lfs -text
```
README.md CHANGED

````diff
@@ -26,23 +26,23 @@
 
 5 seeds are available for each configuration, which makes a total of **750 available models** 📈.
 
-The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example to evaluate the model parameters 🧑‍🏫 It uses JAX 🚀 The hyperparameters used during training are reported in [
+The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example to evaluate the model parameters 🧑‍🏫 It uses JAX 🚀 The hyperparameters used during training are reported in [online_config.json](./online_config.json) and [offline_config.json](./offline_config.json) 🔧
 
 The training code will be available soon ⏳
 
-### Model performances
-| <div style="width:300px; font-size: 30px; font-family:Serif; font-name:Times New Roman" > **EauDeDQN** and **EauDeCQL** achieve high sparsity while keeping performance high. <br> Published at [RLDM](https://arxiv.org/pdf/2503.01437)✨ </br> <div style="font-size: 16px"> <details> <summary id=games>List of Atari games</summary> *BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, VideoPinball.* </details> </div> </div> | <img src="
+### Model sparsity & performances
+| <div style="width:300px; font-size: 30px; font-family:Serif; font-name:Times New Roman" > **EauDeDQN** and **EauDeCQL** achieve high sparsity while keeping performance high. <br> Published at [RLDM](https://arxiv.org/pdf/2503.01437)✨ </br> <div style="font-size: 16px"> <details> <summary id=games>List of Atari games</summary> *BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, VideoPinball.* </details> </div> </div> | <img src="sparsities.png" alt="drawing" width="600px"/> |
 | :-: | :-: |
 
+The episodic returns and lengths are available in the [evaluations](./evaluations/) folder 🔬
+
 ## User installation
 Python 3.10 is recommended. Create a Python virtual environment, activate it, update pip, and install the dependencies:
 ```bash
 python3.10 -m venv env
 source env/bin/activate
-pip install --upgrade pip
-pip install numpy==1.23.5  # to avoid numpy==2.XX
+pip install --upgrade pip setuptools wheel
 pip install -r requirements.txt
-pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 ## Citing `Eau De Q-Network`
````
atari.py ADDED

```python
"""
The environment is inspired by https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
"""

from typing import Dict, Tuple

import ale_py
import cv2
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from gymnasium.wrappers import RecordVideo


class AtariEnv:
    def __init__(self, name: str) -> None:
        self.name = name
        self.state_height, self.state_width = (84, 84)
        self.n_stacked_frames = 4
        self.n_skipped_frames = 4

        gym.register_envs(ale_py)  # To use ALE with gymnasium, which speeds up step()
        self.env = gym.make(
            f"ALE/{self.name}-v5",
            full_action_space=False,
            frameskip=1,  # frame skipping is handled manually in step()
            repeat_action_probability=0.25,  # sticky actions
            max_num_frames_per_episode=100_000,
            continuous=False,
            continuous_action_threshold=0.0,
            render_mode="rgb_array",
        ).env

        self.n_actions = self.env.action_space.n
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
        # Two grayscale buffers: the last two raw frames are max-pooled to remove flickering.
        self.screen_buffer = [
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
        ]

    @property
    def observation(self) -> np.ndarray:
        return np.copy(self.state_[:, :, -1])

    @property
    def state(self) -> jnp.ndarray:
        return jnp.array(self.state_, dtype=jnp.float32)

    def reset(self) -> None:
        self.env.reset()

        self.n_steps = 0

        self.env.env.ale.getScreenGrayscale(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
        self.state_[:, :, -1] = self.resize()

    def step(self, action: jnp.int8) -> Tuple[float, bool]:
        reward = 0

        for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, _, _ = self.env.step(action)

            reward += reward_

            # Grab the grayscale screen for the last two skipped frames only.
            if idx_frame >= self.n_skipped_frames - 2:
                t = idx_frame - (self.n_skipped_frames - 2)
                self.env.env.ale.getScreenGrayscale(self.screen_buffer[t])

            if terminal:
                break

        # Shift the frame stack and append the new pooled and resized frame.
        self.state_ = np.roll(self.state_, -1, axis=-1)
        self.state_[:, :, -1] = self.pool_and_resize()

        self.n_steps += 1

        return reward, terminal

    def pool_and_resize(self) -> np.ndarray:
        np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])

        return self.resize()

    def resize(self) -> np.ndarray:
        return np.asarray(
            cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
            dtype=np.uint8,
        )

    def evaluate_one_simulation(
        self,
        q,
        q_params: Dict,
        horizon: int,
        eps_eval: float,
        exploration_key: jax.random.PRNGKey,
        video_path: str,
    ) -> Tuple[float, bool]:
        # Keep a handle on the ALE so it stays reachable through the RecordVideo wrapper.
        ale = self.env.env.ale
        self.env = RecordVideo(
            self.env,
            video_folder=video_path if video_path is not None else ".",
            name_prefix="",
            episode_trigger=lambda x: video_path is not None,
        )
        self.env.env.ale = ale

        sum_reward = 0
        terminal = False
        self.reset()

        while not terminal and self.n_steps < horizon:
            # Epsilon-greedy evaluation policy.
            exploration_key, key = jax.random.split(exploration_key)
            if jax.random.uniform(key) < eps_eval:
                action = jax.random.choice(key, jnp.arange(self.n_actions)).astype(jnp.int8)
            else:
                action = q.best_action(q_params, self.state)

            reward, terminal = self.step(action)

            sum_reward += reward

        self.env.close()
        return sum_reward, terminal
```
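To make the wrapper concrete, here is a minimal, hedged usage sketch: it builds the environment, resets it, and steps a short random rollout. Only names defined in atari.py are used; a ROM-enabled `ale-py` install (plus `gymnasium` and `opencv-python`, as in the README) is assumed.

```python
# Minimal smoke test for AtariEnv; not part of the repo.
import numpy as np
from atari import AtariEnv

env = AtariEnv("Pong")
env.reset()

terminal, total_reward = False, 0.0
while not terminal and env.n_steps < 100:
    action = np.int8(np.random.randint(env.n_actions))  # uniform random policy
    reward, terminal = env.step(action)
    total_reward += reward

print(env.state.shape)  # (84, 84, 4): four stacked 84x84 grayscale frames
print(total_reward, env.n_steps)
```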
evaluate.ipynb ADDED

```json
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import pickle\n",
    "from atari import AtariEnv\n",
    "from networks import QNetwork\n",
    "\n",
    "# ------- START TO MODIFY ------- #\n",
    "ALGO = \"eaudedqn\"  # choose between eaudedqn, polyprunedqn, dqn, eaudecql, polyprunecql, and cql.\n",
    "GAME = \"SpaceInvaders\"  # choose between BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, and VideoPinball.\n",
    "FEATURE_SIZE = 32  # choose between 32, 512, and 2048.\n",
    "NETWORK_SEED = 1  # choose between 1, 2, 3, 4, and 5.\n",
    "EVALUATION_SEED = 0\n",
    "HORIZON = 27000\n",
    "EPSILON = 0.01\n",
    "RECORD_VIDEO = False\n",
    "# ------- END TO MODIFY ------- #\n",
    "\n",
    "params_path = f\"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{NETWORK_SEED}\"\n",
    "\n",
    "env = AtariEnv(GAME)\n",
    "\n",
    "q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)\n",
    "\n",
    "with open(params_path, \"rb\") as handle:\n",
    "    q_params = pickle.load(handle)\n",
    "\n",
    "return_, absorbing = env.evaluate_one_simulation(\n",
    "    q, q_params, HORIZON, EPSILON, jax.random.PRNGKey(EVALUATION_SEED), params_path + \"_eval\" if RECORD_VIDEO else None\n",
    ")\n",
    "print(\"Undiscounted return:\", return_)\n",
    "print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)\n",
    "non_zeros = sum(jax.tree.leaves(jax.tree.map(jnp.count_nonzero, q_params)))\n",
    "n_params = sum(jax.tree.leaves(jax.tree.map(jnp.size, q_params)))\n",
    "print(\"Sparsity level:\", (1 - jnp.float32(non_zeros) / jnp.float32(n_params)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
```
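The notebook's sparsity computation collapses the whole parameter pytree into one number. A hedged extension, using the same `jax.tree` utilities already imported there, reports the zero-weight fraction per tensor instead; `q_params` is assumed to be the pickled pytree loaded in the notebook.

```python
# Hypothetical per-tensor breakdown of the notebook's global sparsity metric.
import jax
import jax.numpy as jnp

def per_leaf_sparsity(q_params):
    # Fraction of exactly-zero weights in each parameter tensor of the pytree.
    return jax.tree.map(lambda w: 1.0 - jnp.count_nonzero(w) / w.size, q_params)
```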
networks.py ADDED

```python
from typing import Sequence
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn


class DQNNet(nn.Module):
    features: Sequence[int]
    n_actions: int

    @nn.compact
    def __call__(self, x):
        initializer = nn.initializers.xavier_uniform()
        # Standard DQN torso: three convolutions over the normalized frame stack.
        x = nn.relu(
            nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(
                jnp.array(x, ndmin=4) / 255.0
            )
        )
        x = nn.relu(nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x))
        x = x.reshape((x.shape[0], -1))  # flatten the convolutional features

        x = jnp.squeeze(x)  # drop the singleton batch dimension added above

        # Dense layers; the last entry of `features` is the feature size (32, 512, or 2048).
        for idx_layer in range(3, len(self.features)):
            x = nn.relu(nn.Dense(self.features[idx_layer], kernel_init=initializer)(x))

        return nn.Dense(self.n_actions, kernel_init=initializer)(x)


class QNetwork:
    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        self.network = DQNNet(features, n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        # Greedy action: argmax over the Q-values of a single state.
        return jnp.argmax(self.network.apply(params, state)).astype(jnp.int8)
```
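A short sketch of how `QNetwork` can be exercised on its own, assuming randomly initialized parameters rather than a checkpoint pickled in `models/` (the 84x84x4 state shape follows atari.py; the feature size 512 is one of the released widths):

```python
import jax
import jax.numpy as jnp
from networks import QNetwork

q = QNetwork([32, 64, 64, 512], n_actions=6)
dummy_state = jnp.zeros((84, 84, 4), dtype=jnp.float32)
params = q.network.init(jax.random.PRNGKey(0), dummy_state)  # random init, not a released checkpoint
print(q.best_action(params, dummy_state))  # greedy action as jnp.int8
```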
offline_config.json ADDED

```json
{
    "shared_parameters": {
        "features": [
            32,
            64,
            64,
            "Feature Size"
        ],
        "replay_buffer_capacity": 50000,
        "batch_size": 32,
        "update_horizon": 1,
        "gamma": 0.99,
        "learning_rate": 5e-05,
        "architecture_type": "cnn",
        "target_update_frequency": 2000,
        "n_buffers_to_load": 5,
        "n_epochs": 50,
        "n_fitting_steps": 62500
    },
    "eaudecql": {
        "n_networks": 5,
        "max_noise": 3.0,
        "max_speed": 0.01,
        "reset_optimizer": true,
        "alpha_cql": 0.1
    },
    "polyprunecql": {
        "sparcity_start_step": 625000,
        "sparcity_end_step": 2500000,
        "sparcity_update_freq": 1000,
        "final_sparsity": 0.95,
        "alpha_cql": 0.1
    },
    "cql": {
        "alpha_cql": 0.1
    }
}
```
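Since the training code is not released yet, how these files are consumed is not shown in the repo. The sketch below illustrates one plausible reading: merge `shared_parameters` with an algorithm-specific block and substitute the `"Feature Size"` placeholder with the chosen network width. `load_config` is a hypothetical helper, not part of the repo.

```python
import json

def load_config(path: str, algo: str, feature_size: int) -> dict:
    # Hypothetical helper: algorithm-specific keys override shared ones.
    with open(path) as f:
        config = json.load(f)
    params = {**config["shared_parameters"], **config[algo]}
    # Replace the "Feature Size" placeholder with the chosen width.
    params["features"] = [feature_size if v == "Feature Size" else v for v in params["features"]]
    return params

params = load_config("offline_config.json", "eaudecql", 512)
print(params["features"])  # [32, 64, 64, 512]
```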
online_config.json ADDED

```json
{
    "shared_parameters": {
        "replay_buffer_capacity": 1000000,
        "batch_size": 32,
        "update_horizon": 1,
        "gamma": 0.99,
        "learning_rate": 6.25e-05,
        "horizon": 27000,
        "n_epochs": 40,
        "n_training_steps_per_epoch": 250000,
        "n_initial_samples": 20000,
        "epsilon_end": 0.01,
        "epsilon_duration": 250000.0,
        "target_update_frequency": 8000,
        "update_to_data": 4.0,
        "features": [
            32,
            64,
            64,
            "Feature Size"
        ],
        "architecture_type": "cnn"
    },
    "eaudedqn": {
        "n_networks": 5,
        "max_noise": 3.0,
        "max_speed": 0.01,
        "reset_optimizer": true
    },
    "polyprunedqn": {
        "sparcity_start_step": 2000000,
        "sparcity_end_step": 8000000,
        "sparcity_update_freq": 4000,
        "final_sparsity": 0.95
    },
    "dqn": {}
}
```
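For reference, `epsilon_end` and `epsilon_duration` match the standard linear epsilon-greedy decay of DQN-style agents. The exact schedule used in training is not released, so the following is an assumed sketch, not the repo's implementation:

```python
def linear_epsilon(step: int, epsilon_end: float = 0.01, epsilon_duration: float = 250_000.0) -> float:
    # Assumed schedule: decay linearly from 1.0 to epsilon_end over epsilon_duration steps.
    fraction = min(step / epsilon_duration, 1.0)
    return 1.0 + fraction * (epsilon_end - 1.0)

print(linear_epsilon(0), linear_epsilon(125_000), linear_epsilon(1_000_000))  # 1.0 0.505 0.01
```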