|
| 1 | +import pytest |
| 2 | +import os |
| 3 | +from easydict import EasyDict |
| 4 | +from copy import deepcopy |
| 5 | + |
| 6 | +from dizoo.classic_control.cartpole.config.cartpole_dqn_config \ |
| 7 | +import cartpole_dqn_config, cartpole_dqn_create_config |
| 8 | +from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config \ |
| 9 | +import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config |
| 10 | +from ding.entry import serial_pipeline, serial_pipeline_reward_model_offpolicy |
| 11 | +from ding.entry.application_entry_drex_collect_data import drex_collecting_data |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.unittest |
| 15 | +def test_drex(): |
| 16 | + exp_name = 'test_serial_pipeline_drex_expert' |
| 17 | + config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)] |
| 18 | + config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 |
| 19 | + config[0].exp_name = exp_name |
| 20 | + expert_policy = serial_pipeline(config, seed=0) |
| 21 | + |
| 22 | + exp_name = 'test_serial_pipeline_drex_collect' |
| 23 | + config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)] |
| 24 | + config[0].exp_name = exp_name |
| 25 | + config[0].reward_model.exp_name = exp_name |
| 26 | + config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar' |
| 27 | + config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params' |
| 28 | + config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect' |
| 29 | + config[0].reward_model.checkpoint_max = 100 |
| 30 | + config[0].reward_model.checkpoint_step = 100 |
| 31 | + config[0].reward_model.num_snippets = 100 |
| 32 | + |
| 33 | + args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) |
| 34 | + args.cfg[0].policy.collect.n_episode = 8 |
| 35 | + del args.cfg[0].policy.collect.n_sample |
| 36 | + args.cfg[0].bc_iteration = 1000 # for unittest |
| 37 | + args.cfg[1].policy.type = 'bc' |
| 38 | + drex_collecting_data(args=args) |
| 39 | + try: |
| 40 | + serial_pipeline_reward_model_offpolicy( |
| 41 | + config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False |
| 42 | + ) |
| 43 | + except Exception: |
| 44 | + assert False, "pipeline fail" |
| 45 | + finally: |
| 46 | + os.popen('rm -rf test_serial_pipeline_drex*') |
0 commit comments