Skip to content

Commit

Permalink
Support for training on MADRAS env
Browse files Browse the repository at this point in the history
  • Loading branch information
buridiaditya committed Jan 5, 2019
1 parent 198bbed commit 340ac6b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 6 deletions.
3 changes: 2 additions & 1 deletion baselines/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def learn(
network,
env,
seed=None,
nsteps=5,
nsteps=20,
total_timesteps=int(80e6),
vf_coef=0.5,
ent_coef=0.01,
Expand Down Expand Up @@ -187,6 +187,7 @@ def learn(
nenvs = env.num_envs
policy = build_policy(env, network, **network_kwargs)

print('Parallel %d number'%(nenvs))
# Instantiate the model object (that creates step_model and train_model)
model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
Expand Down
3 changes: 2 additions & 1 deletion baselines/a2c/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def run(self):
self.dones = dones
for n, done in enumerate(dones):
if done:
self.obs[n] = self.obs[n]*0
# self.obs[n] = self.obs[n]*0
pass
self.obs = obs
mb_rewards.append(rewards)
mb_dones.append(self.dones)
Expand Down
14 changes: 14 additions & 0 deletions baselines/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def network_fn(X):

return network_fn

@register("mynn")
def mynn(num_layers=2, num_hidden=[300,500], activation=tf.tanh, layer_norm=False):
def network_fn(X):
h = tf.layers.flatten(X)
for i in range(num_layers):
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden[i], init_scale=np.sqrt(2))
if layer_norm:
h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
h = activation(h)

return h

return network_fn


@register("cnn")
def cnn(**conv_kwargs):
Expand Down
5 changes: 5 additions & 0 deletions baselines/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def reset(self):
remote.send(('reset', None))
return np.stack([remote.recv() for remote in self.remotes])

def reset_envno(self,no):
self._assert_not_closed()
self.remotes[no].send(('reset', None))
return self.remotes[no].recv()

def close_extras(self):
self.closed = True
if self.waiting:
Expand Down
10 changes: 6 additions & 4 deletions baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
'SpaceInvaders-Snes',
}

_game_envs['madras'] = {'gym-torcs-v0','gym-madras-v0'}
_game_envs['madras'] = {'Madras-v0'}
def train(args, extra_args):
env_type, env_id = get_env_type(args.env)
print('env_type: {}'.format(env_type))
Expand Down Expand Up @@ -88,6 +88,7 @@ def build_env(args):
ncpu = multiprocessing.cpu_count()
if sys.platform == 'darwin': ncpu //= 2
nenv = args.num_env or ncpu
print('Found %d CPUs'%(nenv))
alg = args.alg
seed = args.seed

Expand Down Expand Up @@ -196,23 +197,24 @@ def main(args):
rank = MPI.COMM_WORLD.Get_rank()

model, env = train(args, extra_args)
env.close()
# env.close()

if args.save_path is not None and rank == 0:
save_path = osp.expanduser(args.save_path)
model.save(save_path)

if args.play:
logger.log("Running trained model")
env = build_env(args)
# env = build_env(args)
obs = env.reset()
def initialize_placeholders(nlstm=128,**kwargs):
return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))
state, dones = initialize_placeholders(**extra_args)
while True:
actions, _, state, _ = model.step(obs,S=state, M=dones)
# actions, _, state, _ = model.step(obs)
obs, _, done, _ = env.step(actions)
env.render()
# env.render()
done = done.any() if isinstance(done, np.ndarray) else done

if done:
Expand Down

0 comments on commit 340ac6b

Please sign in to comment.