Commit 9c8ef81

Merge branch 'master' into soft-actor-critic

2 parents: bd15d96 + b25256c

31 files changed: +1073 -141 lines

README.md (+3)
@@ -49,6 +49,7 @@ For more information, you can refer to [ChainerRL's documentation](http://chaine
 | PCL (Path Consistency Learning) |||||
 | PPO ||| x | x |
 | TRPO ||| x | x |
+| TD3 | x || x | x |
 
 Following algorithms have been implemented in ChainerRL:
 - A3C (Asynchronous Advantage Actor-Critic)
@@ -63,6 +64,7 @@ Following algorithms have been implemented in ChainerRL:
 - PCL (Path Consistency Learning)
 - PPO (Proximal Policy Optimization)
 - TRPO (Trust Region Policy Optimization)
+- TD3 (Twin Delayed Deep Deterministic policy gradient algorithm)
 
 Q-function based algorithms such as DQN can utilize a Normalized Advantage Function (NAF) to tackle continuous-action problems as well as DQN-like discrete output networks.
 
@@ -84,6 +86,7 @@ The following papers have been implemented in ChainerRL:
 - [Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477)
 - [Sample Efficient Actor-Critic with Experience Replay](https://arxiv.org/abs/1611.01224)
 - [Bridging the Gap Between Value and Policy Based Reinforcement Learning](https://arxiv.org/abs/1702.08892)
+- [Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477)
 
 
 ## Visualization

chainerrl/agents/ddpg.py (+2)
@@ -435,9 +435,11 @@ def batch_observe_and_train(
                     next_state=batch_obs[i],
                     next_action=None,
                     is_state_terminal=batch_done[i],
+                    env_id=i,
                 )
                 if batch_reset[i] or batch_done[i]:
                     self.batch_last_obs[i] = None
+                    self.replay_buffer.stop_current_episode(env_id=i)
             self.replay_updater.update_if_necessary(self.t)
 
     def batch_observe(self, batch_obs, batch_reward,

chainerrl/agents/dqn.py (+2)
@@ -470,9 +470,11 @@ def batch_observe_and_train(self, batch_obs, batch_reward,
                     next_state=batch_obs[i],
                     next_action=None,
                     is_state_terminal=batch_done[i],
+                    env_id=i,
                 )
                 if batch_reset[i] or batch_done[i]:
                     self.batch_last_obs[i] = None
+                    self.replay_buffer.stop_current_episode(env_id=i)
             self.replay_updater.update_if_necessary(self.t)
 
     def batch_observe(self, batch_obs, batch_reward,

chainerrl/agents/td3.py (+13 -9)
@@ -241,7 +241,7 @@ def update(self, experiences, errors_out=None):
         self.update_policy(batch)
         self.sync_target_network()
 
-    def select_greedy_action(self, obs):
+    def select_onpolicy_action(self, obs):
         with chainer.no_backprop_mode(), chainer.using_config('train', False):
             s = self.batch_states([obs], self.xp, self.phi)
             action = self.policy(s).sample().array
@@ -255,8 +255,9 @@ def act_and_train(self, obs, reward):
                 and self.policy_optimizer.t == 0):
             action = self.burnin_action_func()
         else:
-            greedy_action = self.select_greedy_action(obs)
-            action = self.explorer.select_action(self.t, lambda: greedy_action)
+            onpolicy_action = self.select_onpolicy_action(obs)
+            action = self.explorer.select_action(
+                self.t, lambda: onpolicy_action)
         self.t += 1
 
         if self.last_state is not None:
@@ -278,16 +279,16 @@ def act_and_train(self, obs, reward):
         return self.last_action
 
     def act(self, obs):
-        return self.select_greedy_action(obs)
+        return self.select_onpolicy_action(obs)
 
-    def batch_select_greedy_action(self, batch_obs):
+    def batch_select_onpolicy_action(self, batch_obs):
         with chainer.using_config('train', False), chainer.no_backprop_mode():
             batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
             batch_action = self.policy(batch_xs).sample().array
         return list(cuda.to_cpu(batch_action))
 
     def batch_act(self, batch_obs):
-        return self.batch_select_greedy_action(batch_obs)
+        return self.batch_select_onpolicy_action(batch_obs)
 
     def batch_act_and_train(self, batch_obs):
         """Select a batch of actions for training.
@@ -304,11 +305,12 @@ def batch_act_and_train(self, batch_obs):
             batch_action = [self.burnin_action_func()
                             for _ in range(len(batch_obs))]
         else:
-            batch_greedy_action = self.batch_select_greedy_action(batch_obs)
+            batch_onpolicy_action = self.batch_select_onpolicy_action(
+                batch_obs)
             batch_action = [
                 self.explorer.select_action(
-                    self.t, lambda: batch_greedy_action[i])
-                for i in range(len(batch_greedy_action))]
+                    self.t, lambda: batch_onpolicy_action[i])
+                for i in range(len(batch_onpolicy_action))]
 
         self.batch_last_obs = list(batch_obs)
         self.batch_last_action = list(batch_action)
@@ -329,9 +331,11 @@ def batch_observe_and_train(
                     next_state=batch_obs[i],
                     next_action=None,
                     is_state_terminal=batch_done[i],
+                    env_id=i,
                 )
                 if batch_reset[i] or batch_done[i]:
                     self.batch_last_obs[i] = None
+                    self.replay_buffer.stop_current_episode(env_id=i)
             self.replay_updater.update_if_necessary(self.t)
 
     def batch_observe(self, batch_obs, batch_reward,
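
The `env_id=i` / `stop_current_episode(env_id=i)` pattern added in the ddpg.py, dqn.py, and td3.py hunks above lets a single replay buffer keep a separate open episode per parallel environment, so finishing an episode in environment i does not cut off the episodes of the other environments. The sketch below only illustrates that bookkeeping and is not ChainerRL's actual replay buffer: the method names mirror the calls in the diffs, while the class name `PerEnvEpisodicBuffer` and its internals are made up for this example.

from collections import defaultdict


class PerEnvEpisodicBuffer(object):
    """Illustrative buffer that keeps one open episode per env_id."""

    def __init__(self):
        self.finished_episodes = []        # closed episodes, from any env
        self.current = defaultdict(list)   # env_id -> currently open episode

    def append(self, env_id=0, **transition):
        # Transitions from different envs never interleave within an episode.
        self.current[env_id].append(transition)

    def stop_current_episode(self, env_id=0):
        # Close only this environment's episode; the others stay open.
        if self.current[env_id]:
            self.finished_episodes.append(self.current[env_id])
            self.current[env_id] = []


buf = PerEnvEpisodicBuffer()
buf.append(env_id=3, reward=1.0, is_state_terminal=True)
buf.stop_current_episode(env_id=3)  # affects env 3 only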

chainerrl/experiments/evaluator.py (+1)
@@ -230,6 +230,7 @@ def eval_performance(env, agent, n_steps, n_episodes, max_episode_len=None,
     Args:
         env (Environment): Environment used for evaluation
         agent (Agent): Agent to evaluate.
+        n_steps (int): Number of timesteps to evaluate for.
         n_episodes (int): Number of evaluation episodes.
         max_episode_len (int or None): If specified, episodes longer than this
             value will be truncated.
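
With the `n_steps` entry added above, the docstring now covers both ways of bounding an evaluation. A usage sketch, assuming the convention in ChainerRL's evaluator that exactly one of `n_steps` and `n_episodes` is given while the other is passed as None, and that `env` and `agent` are already constructed:

from chainerrl.experiments.evaluator import eval_performance

# Evaluate for a fixed number of episodes, with a per-episode length cap.
stats_by_episodes = eval_performance(
    env, agent, n_steps=None, n_episodes=10, max_episode_len=1000)

# Or evaluate for a fixed total number of timesteps instead.
stats_by_steps = eval_performance(env, agent, n_steps=10000, n_episodes=None)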

chainerrl/initializers/__init__.py (+2)
@@ -1,5 +1,7 @@
 from chainerrl.initializers.constant import VarianceScalingConstant  # NOQA
 
+from chainerrl.initializers.orthogonal import Orthogonal  # NOQA
+
 # LeCunNormal was merged into Chainer v3, thus removed from ChainerRL.
 # For backward compatibility, it is still imported in this namespace.
 from chainer.initializers import LeCunNormal  # NOQA

chainerrl/initializers/orthogonal.py (+104, new file)
@@ -0,0 +1,104 @@
+"""
+This is copied from https://github.com/chainer/chainer/pull/6031 and will be
+unnecessary once the PR is merged to Chainer.
+"""
+import functools
+import operator
+
+import numpy
+
+from chainer import cuda
+from chainer import initializer
+
+
+# Only Chainer v6 or later has chainer.utils.size_of_shape
+def size_of_shape(shape):
+    return functools.reduce(operator.mul, shape, 1)
+
+
+_orthogonal_constraints = {  # (assert emb., assert proj.)
+    'auto': (False, False),
+    'projection': (False, True),
+    'embedding': (True, False),
+    'basis': (True, True),
+}
+
+
+# Original code forked from MIT licensed keras project
+# https://github.com/fchollet/keras/blob/master/keras/initializations.py
+
+class Orthogonal(initializer.Initializer):
+    """Initializes array with an orthogonal system.
+
+    This initializer first makes a matrix of the same shape as the
+    array to be initialized whose elements are drawn independently from
+    standard Gaussian distribution.
+    Next, it applies QR decomposition to (the transpose of) the matrix.
+    To make the decomposition (almost surely) unique, we require the diagonal
+    of the triangular matrix R to be non-negative (see e.g. Edelman & Rao,
+    https://web.eecs.umich.edu/~rajnrao/Acta05rmt.pdf).
+    Then, it initializes the array with the (semi-)orthogonal matrix Q.
+    Finally, the array is multiplied by the constant ``scale``.
+
+    If the ``ndim`` of the input array is more than 2, we consider the array
+    to be a matrix by concatenating all axes except the first one.
+
+    The number of vectors consisting of the orthogonal system
+    (i.e. first element of the shape of the array) must be equal to or smaller
+    than the dimension of each vector (i.e. second element of the shape of
+    the array).
+
+    Attributes:
+        scale (float): A constant to be multiplied by.
+        dtype: Data type specifier.
+        mode (str): Assertion on the initialized shape.
+            ``'auto'`` (default), ``'projection'`` (before v6),
+            ``'embedding'``, or ``'basis'``.
+
+    Reference: Saxe et al., https://arxiv.org/abs/1312.6120
+
+    """
+
+    def __init__(self, scale=1.1, dtype=None, mode='auto'):
+        self.scale = scale
+        self.mode = mode
+        try:
+            self._checks = _orthogonal_constraints[mode]
+        except KeyError:
+            raise ValueError(
+                'Invalid mode: {}. Choose from {}.'.format(
+                    repr(mode),
+                    ', '.join(repr(m) for m in _orthogonal_constraints)))
+        super(Orthogonal, self).__init__(dtype)
+
+    # TODO(Kenta Oono)
+    # How do we treat overcomplete base-system case?
+    def __call__(self, array):
+        if self.dtype is not None:
+            assert array.dtype == self.dtype
+        xp = cuda.get_array_module(array)
+        if not array.shape:  # 0-dim case
+            array[...] = self.scale * (2 * numpy.random.randint(2) - 1)
+        elif not array.size:
+            raise ValueError('Array to be initialized must be non-empty.')
+        else:
+            # numpy.prod returns float value when the argument is empty.
+            out_dim = len(array)
+            in_dim = size_of_shape(array.shape[1:])
+            if (in_dim > out_dim and self._checks[0]) or (
+                    in_dim < out_dim and self._checks[1]):
+                raise ValueError(
+                    'Cannot make orthogonal {}.'
+                    'shape = {}, interpreted as '
+                    '{}-dim input and {}-dim output.'.format(
+                        self.mode, array.shape, in_dim, out_dim))
+            transpose = in_dim > out_dim
+            a = numpy.random.normal(size=(out_dim, in_dim))
+            if transpose:
+                a = a.T
+            # cupy.linalg.qr requires cusolver in CUDA 8+
+            q, r = numpy.linalg.qr(a)
+            q *= numpy.copysign(self.scale, numpy.diag(r))
+            if transpose:
+                q = q.T
+            array[...] = xp.asarray(q.reshape(array.shape))
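
Combined with the `__init__.py` change above, the new initializer is importable as `chainerrl.initializers.Orthogonal` and can be passed wherever Chainer accepts a weight initializer. A minimal sketch, assuming a plain `chainer.links.Linear` layer (the layer sizes are arbitrary); with `scale=1.0` the rows of the resulting weight matrix are orthonormal up to floating-point error:

import chainer.links as L
import numpy as np

from chainerrl.initializers import Orthogonal

# (out_size, in_size) = (32, 64): 32 row vectors of dimension 64,
# drawn as an orthogonal system and scaled by `scale`.
layer = L.Linear(64, 32, initialW=Orthogonal(scale=1.0))

W = layer.W.array
print(np.allclose(W.dot(W.T), np.eye(32), atol=1e-4))  # expected: True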
