
Commit c0a6fbb

Fix typo: stocahstic -> stochastic
1 parent 54d9db5 commit c0a6fbb

File tree: gailtf/algo/behavior_clone.py · gailtf/algo/trpo_mpi.py · main.py

3 files changed: +9 -9 lines changed

gailtf/algo/behavior_clone.py

+3 -3

@@ -7,7 +7,7 @@
 from common.statistics import stats
 import ipdb
 
-def evaluate(env, policy_func, load_model_path, stocahstic_policy=False, number_trajs=10):
+def evaluate(env, policy_func, load_model_path, stochastic_policy=False, number_trajs=10):
 from algo.trpo_mpi import traj_episode_generator
 ob_space = env.observation_space
 ac_space = env.action_space
@@ -16,7 +16,7 @@ def evaluate(env, policy_func, load_model_path, stocahstic_policy=False, number_trajs=10):
 ob = U.get_placeholder_cached(name="ob")
 ac = pi.pdtype.sample_placeholder([None])
 stochastic = U.get_placeholder_cached(name="stochastic")
-ep_gen = traj_episode_generator(pi, env, 1024, stochastic=stocahstic_policy)
+ep_gen = traj_episode_generator(pi, env, 1024, stochastic=stochastic_policy)
 U.load_state(load_model_path)
 len_list = []
 ret_list = []
@@ -25,7 +25,7 @@ def evaluate(env, policy_func, load_model_path, stocahstic_policy=False, number_trajs=10):
 ep_len, ep_ret = traj['ep_len'], traj['ep_ret']
 len_list.append(ep_len)
 ret_list.append(ep_ret)
-if stocahstic_policy:
+if stochastic_policy:
 print ('stochastic policy:')
 else:
 print ('deterministic policy:' )

gailtf/algo/trpo_mpi.py

+3 -3

@@ -385,7 +385,7 @@ def traj_episode_generator(pi, env, horizon, stochastic):
 t += 1
 
 def evaluate(env, policy_func, load_model_path, timesteps_per_batch, number_trajs=10,
-stocahstic_policy=False):
+stochastic_policy=False):
 
 from tqdm import tqdm
 # Setup network
@@ -396,7 +396,7 @@ def evaluate(env, policy_func, load_model_path, timesteps_per_batch, number_trajs=10,
 U.initialize()
 # Prepare for rollouts
 # ----------------------------------------
-ep_gen = traj_episode_generator(pi, env, timesteps_per_batch, stochastic=stocahstic_policy)
+ep_gen = traj_episode_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
 U.load_state(load_model_path)
 
 len_list = []
@@ -406,7 +406,7 @@ def evaluate(env, policy_func, load_model_path, timesteps_per_batch, number_trajs=10,
 ep_len, ep_ret = traj['ep_len'], traj['ep_ret']
 len_list.append(ep_len)
 ret_list.append(ep_ret)
-if stocahstic_policy:
+if stochastic_policy:
 print ('stochastic policy:')
 else:
 print ('deterministic policy:' )

main.py

+3 -3

@@ -20,7 +20,7 @@ def argsparser():
 # Task
 parser.add_argument('--task', type=str, choices=['train', 'evaluate'], default='train')
 # for evaluatation
-parser.add_argument('--stocahstic_policy', type=bool, default=False)
+parser.add_argument('--stochastic_policy', type=bool, default=False)
 # Mujoco Dataset Configuration
 parser.add_argument('--ret_threshold', help='the return threshold for the expert trajectories', type=int, default=0)
 parser.add_argument('--traj_limitation', type=int, default=np.inf)
@@ -79,7 +79,7 @@ def policy_fn(name, ob_space, ac_space, reuse=False):
 # Pretrain with behavior cloning
 from gailtf.algo import behavior_clone
 if args.algo == 'bc' and args.task == 'evaluate':
-behavior_clone.evaluate(env, policy_fn, args.load_model_path, stocahstic_policy=args.stocahstic_policy)
+behavior_clone.evaluate(env, policy_fn, args.load_model_path, stochastic_policy=args.stochastic_policy)
 sys.exit()
 pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
 max_iters=args.BC_max_iter, pretrained=args.pretrained,
@@ -114,7 +114,7 @@ def policy_fn(name, ob_space, ac_space, reuse=False):
 task_name=task_name)
 elif args.task == 'evaluate':
 trpo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
-number_trajs=10, stocahstic_policy=args.stocahstic_policy)
+number_trajs=10, stochastic_policy=args.stochastic_policy)
 else: raise NotImplementedError
 else: raise NotImplementedError
 

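Note on how the renamed flag is consumed: --stochastic_policy is declared with type=bool, so argparse treats any non-empty value (including the string "False") as True; only omitting the flag keeps the default of False. The snippet below is a minimal, self-contained sketch of the flag's path from the parser into an evaluate call. It is an illustration under stated assumptions, not the gail-tf code itself: evaluate here is a hypothetical stand-in for behavior_clone.evaluate / trpo_mpi.evaluate, and only the renamed keyword argument mirrors this commit.

import argparse

# Hypothetical stand-in for behavior_clone.evaluate / trpo_mpi.evaluate;
# only the stochastic_policy keyword is meant to mirror the commit.
def evaluate(stochastic_policy=False, number_trajs=10):
    print('stochastic policy:' if stochastic_policy else 'deterministic policy:')

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, choices=['train', 'evaluate'], default='train')
# for evaluation: any non-empty value parses as True because of type=bool
parser.add_argument('--stochastic_policy', type=bool, default=False)

args = parser.parse_args(['--task', 'evaluate', '--stochastic_policy', 'True'])
if args.task == 'evaluate':
    evaluate(stochastic_policy=args.stochastic_policy)  # prints 'stochastic policy:'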
0 commit comments
