-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_framework.py
65 lines (50 loc) · 1.87 KB
/
test_framework.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import gym
import atexit
import threading
import minerl
from basalt_baselines.bc import bc_baseline, WRAPPERS as bc_wrappers
import numpy as np
from test_submission_code import KAIROS_MineRLAgent, MineRLAgent, Episode, EpisodeDone, MineRLBehavioralCloningAgent
from minerl.herobraine.wrappers import downscale_wrapper
from basalt_utils.utils import wrap_env
import torch as th
# Optional verbose logging for debugging:
# import coloredlogs
# coloredlogs.install(logging.DEBUG)

# Evaluation settings, overridable through environment variables.
# Gym id of the environment to evaluate on; the sanity check in main()
# requires a "HighRes" variant.
MINERL_GYM_ENV = os.environ.get('MINERL_GYM_ENV', 'MineRLBasaltFindCaveHighRes-v0')
# Number of episodes the evaluator runs.
MINERL_MAX_EVALUATION_EPISODES = int(os.environ.get('MINERL_MAX_EVALUATION_EPISODES', '2'))
# Evaluation runs on a single thread.
EVALUATION_THREAD_COUNT = 1

####################
# EVALUATION CODE #
####################
def main():
    """Run the local evaluation loop for the submission agent.

    Builds a ``KAIROS_MineRLAgent`` for ``MINERL_GYM_ENV`` and loads it,
    creates the environment with ``gym.make``, checks that it is a "HighRes"
    variant, wraps it in ``DownscaleWrapper``, and then runs
    ``MINERL_MAX_EVALUATION_EPISODES`` episodes on a single evaluator thread.

    Raises:
        AssertionError: if ``MINERL_MAX_EVALUATION_EPISODES`` or
            ``EVALUATION_THREAD_COUNT`` is not positive.
        RuntimeError: if the first dimension of the environment's ``pov``
            observation space is not 1024 (i.e. not a "HighRes" variant).
        Exception: any error other than ``EpisodeDone`` raised on the
            evaluator thread is re-raised here once the thread has finished.
    """
    # Validate configuration before building/loading the agent so a bad
    # setting fails fast (load_agent() presumably reads model weights).
    assert MINERL_MAX_EVALUATION_EPISODES > 0
    assert EVALUATION_THREAD_COUNT > 0

    # agent = MineRLBehavioralCloningAgent()
    agent = KAIROS_MineRLAgent(env_name=MINERL_GYM_ENV)
    agent.load_agent()

    env = gym.make(MINERL_GYM_ENV)

    # Bit of sanity check
    if env.observation_space['pov'].shape[0] != 1024:
        # The atexit cleanup is not registered yet, so close the freshly
        # created environment here instead of leaking it.
        env.close()
        raise RuntimeError('The MineRL environment should be a "HighRes" variant.')

    # Apply downscale wrapper to turn (1024, 1024) observations into (64, 64)
    env = downscale_wrapper.DownscaleWrapper(env)

    # Ensure that videos are closed properly
    @atexit.register
    def cleanup_env():
        env.close()

    # Errors on the evaluator thread would otherwise only be printed by the
    # thread excepthook while main() still returned normally; collect them
    # so they can be re-raised on the main thread.
    thread_errors = []

    # A simple function to evaluate on episodes!
    def evaluate(evaluator_id, env):
        print("[{}] Starting evaluator.".format(evaluator_id))
        try:
            # `_` rather than reusing the evaluator id's name: the original
            # loop variable shadowed it and mislabelled the log lines.
            for _ in range(MINERL_MAX_EVALUATION_EPISODES):
                try:
                    agent.run_agent_on_episode(Episode(env))
                except EpisodeDone:
                    print("[{}] Episode complete".format(evaluator_id))
        except Exception as err:
            thread_errors.append(err)

    thread = threading.Thread(target=evaluate, args=(0, env))
    thread.start()
    thread.join()

    if thread_errors:
        raise thread_errors[0]
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()