@@ -20,7 +20,7 @@ def argsparser():
   # Task
   parser.add_argument('--task', type=str, choices=['train', 'evaluate'], default='train')
   # for evaluatation
-  parser.add_argument('--stocahstic_policy', type=bool, default=False)
+  parser.add_argument('--stochastic_policy', type=bool, default=False)
   # Mujoco Dataset Configuration
   parser.add_argument('--ret_threshold', help='the return threshold for the expert trajectories', type=int, default=0)
   parser.add_argument('--traj_limitation', type=int, default=np.inf)
@@ -79,7 +79,7 @@ def policy_fn(name, ob_space, ac_space, reuse=False):
   # Pretrain with behavior cloning
   from gailtf.algo import behavior_clone
   if args.algo == 'bc' and args.task == 'evaluate':
-    behavior_clone.evaluate(env, policy_fn, args.load_model_path, stocahstic_policy=args.stocahstic_policy)
+    behavior_clone.evaluate(env, policy_fn, args.load_model_path, stochastic_policy=args.stochastic_policy)
     sys.exit()
   pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
                                            max_iters=args.BC_max_iter, pretrained=args.pretrained,
@@ -114,7 +114,7 @@ def policy_fn(name, ob_space, ac_space, reuse=False):
                      task_name=task_name)
     elif args.task == 'evaluate':
       trpo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
-                        number_trajs=10, stocahstic_policy=args.stocahstic_policy)
+                        number_trajs=10, stochastic_policy=args.stochastic_policy)
     else: raise NotImplementedError
   else: raise NotImplementedError
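
Side note, not part of this commit: with argparse, type=bool does not parse boolean strings the way one might expect, because argparse converts the raw command-line string and bool('False') is True, so any non-empty value enables the flag. A minimal sketch of a common workaround, using a hypothetical str2bool converter that is not present in this repo:

import argparse

def str2bool(v):
    # Map common truthy/falsy spellings to a real boolean;
    # anything else is rejected instead of silently becoming True.
    if v.lower() in ('true', '1', 'yes'):
        return True
    if v.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

parser = argparse.ArgumentParser()
parser.add_argument('--stochastic_policy', type=str2bool, default=False)

With this converter, passing --stochastic_policy False on the command line actually yields False, whereas type=bool would yield True.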