Skip to content

Commit 9e63ef1

Browse files
committed
fix style for drex unittest
1 parent ff4de47 commit 9e63ef1

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import os
3+
from easydict import EasyDict
4+
from copy import deepcopy
5+
6+
from dizoo.classic_control.cartpole.config.cartpole_dqn_config \
7+
import cartpole_dqn_config, cartpole_dqn_create_config
8+
from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config \
9+
import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config
10+
from ding.entry import serial_pipeline, serial_pipeline_reward_model_offpolicy
11+
from ding.entry.application_entry_drex_collect_data import drex_collecting_data
12+
13+
14+
@pytest.mark.unittest
15+
def test_drex():
16+
exp_name = 'test_serial_pipeline_drex_expert'
17+
config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
18+
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
19+
config[0].exp_name = exp_name
20+
expert_policy = serial_pipeline(config, seed=0)
21+
22+
exp_name = 'test_serial_pipeline_drex_collect'
23+
config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)]
24+
config[0].exp_name = exp_name
25+
config[0].reward_model.exp_name = exp_name
26+
config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar'
27+
config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params'
28+
config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect'
29+
config[0].reward_model.checkpoint_max = 100
30+
config[0].reward_model.checkpoint_step = 100
31+
config[0].reward_model.num_snippets = 100
32+
33+
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
34+
args.cfg[0].policy.collect.n_episode = 8
35+
del args.cfg[0].policy.collect.n_sample
36+
args.cfg[0].bc_iteration = 1000 # for unittest
37+
args.cfg[1].policy.type = 'bc'
38+
drex_collecting_data(args=args)
39+
try:
40+
serial_pipeline_reward_model_offpolicy(
41+
config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
42+
)
43+
except Exception:
44+
assert False, "pipeline fail"
45+
finally:
46+
os.popen('rm -rf test_serial_pipeline_drex*')

0 commit comments

Comments
 (0)