-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_generic.py
173 lines (139 loc) · 5.84 KB
/
eval_generic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# for pre_generated_attacks, override the path to files by:
# python eval_generic.py --config-name=attack_config attack=pre_generated_attack attack.obs.agentview_image.path='aaaaa'
# for online attacks, override params by:
# python eval_generic.py --config-name=attack_config attack=online_attack attack.alpha=0.001 attack.ntarget=null
# multi run (2 experiments x 2 checkpoint paths = 4 runs):
# python eval_generic.py --config-name=attack_config --multirun '+experiment=default,transfer' target_ckpt_file_path=1,2
import sys
# Use line-buffering for both stdout and stderr by re-opening their file
# descriptors with buffering=1. This is done before the remaining imports,
# so the order of the statements below is intentional.
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
import hydra
from omegaconf import OmegaConf
import pathlib
from diffusion_policy.workspace.base_workspace import BaseWorkspace
import torch
import dill
import numpy as np
from diffusion_policy.attack.adversarial_noise import AdversarialNoise
import os
import json
import wandb
import copy
# Allows arbitrary Python code execution in configs using the ${eval:''} resolver.
# NOTE(review): the resolver is the builtin eval(), so config files and
# command-line overrides should be treated as trusted input only.
OmegaConf.register_new_resolver("eval", eval, replace=True)
@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.joinpath(
        'attack_configs', 'eval_attack_config'))
)
def main(cfg: OmegaConf):
    """Evaluate a checkpointed policy, optionally under an adversarial attack.

    Steps, in order:
      1. Resolve the Hydra config and check for the mandatory checkpoint key.
      2. Load the checkpoint payload and take the training config stored in it.
      3. Patch the attack config (patch path, optional DDIM scheduler).
      4. Rebuild the workspace, pick the (EMA) policy and move it to the device.
      5. Instantiate the env runner with the attack config and run it.
      6. Optionally log to wandb, then dump the results to ``eval_log.json``
         in the workspace output directory.

    Args:
        cfg: The composed Hydra/OmegaConf configuration for this evaluation.

    Raises:
        RuntimeError: If ``ckpt_path`` is reported as a missing config key.
    """
    print(OmegaConf.to_yaml(cfg))

    # Resolve immediately so all the ${now:} resolvers use the same time.
    OmegaConf.resolve(cfg)

    # Missing key check.
    # NOTE(review): this tests for 'ckpt_path', but the checkpoint is read
    # from `cfg.ckpt.path` below -- confirm which key the config defines.
    missing_keys: set[str] = OmegaConf.missing_keys(cfg)
    if 'ckpt_path' in missing_keys:
        raise RuntimeError(
            "missing key 'ckpt_path' in config file, pass it in as a command line argument 'ckpt_path=\"...\"'"
        )

    # Hook for adding or modifying config fields: struct mode is switched off
    # and back on so that new keys could be inserted in between.
    OmegaConf.set_struct(cfg, False)
    # e.g. cfg.test_new_field = 'test_new_field'
    OmegaConf.set_struct(cfg, True)
    print(OmegaConf.to_yaml(cfg))

    # Load the checkpoint. The file is opened in a `with` block so the handle
    # is closed as soon as the payload has been read.
    # NOTE: the payload is unpickled with dill, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    with open(cfg.ckpt.path, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    workspace_cfg = payload['cfg']

    # Complete the patch path (relative to the current working directory)
    # for a physical attack.
    # NOTE(review): this reads the top-level `cfg.attack_type`, while other
    # checks in this function read `cfg.attack.attack_type` -- confirm that
    # both keys exist in the config.
    if cfg.attack_type == 'physical_patch':
        cfg.attack.path = os.path.join(os.getcwd(), cfg.attack.path)

    # TODO
    # Replace some components of the workspace_cfg.
    workspace_cfg.logging = cfg.logging

    # Set up the scheduler entry for the online full-chain attack: a DDIM
    # scheduler config derived from the policy's noise scheduler when
    # `use_ddim` is set, otherwise None.
    if cfg.attack.attack_type == 'online' and cfg.attack.online_pgd.type == 'full_chain':
        OmegaConf.set_struct(cfg, False)
        if cfg.attack.online_pgd.use_ddim:
            cfg.attack.online_pgd.scheduler = copy.deepcopy(workspace_cfg.policy.noise_scheduler)
            cfg.attack.online_pgd.scheduler['_target_'] = 'diffusers.schedulers.scheduling_ddim.DDIMScheduler'
            # Remove the variance_type field from the copied scheduler config.
            cfg.attack.online_pgd.scheduler.pop('variance_type')
        else:
            cfg.attack.online_pgd.scheduler = None
        OmegaConf.set_struct(cfg, True)
    print(OmegaConf.to_yaml(cfg))

    # Rebuild the workspace from the training config and restore its state.
    cls = hydra.utils.get_class(workspace_cfg._target_)
    workspace: BaseWorkspace = cls(workspace_cfg)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # Get the policy from the workspace, preferring the EMA model when the
    # training config used one.
    policy = workspace.model
    if workspace_cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(cfg.device)
    policy.to(device)
    policy.eval()

    # Lower the number of environments: no training environments are run.
    # (n_test, n_test_vis, max_steps and n_envs of the env runner can be
    # overridden here in the same way for quick debugging runs.)
    workspace_cfg.task.env_runner.n_train = 0
    workspace_cfg.task.env_runner.n_train_vis = 0

    env_runner = hydra.utils.instantiate(
        workspace_cfg.task.env_runner,
        output_dir=workspace.output_dir,
        attack_config=cfg.attack if cfg.attack.attack_type is not None else None
    )

    noise = None
    simulated_physical_attack = None
    tf_dict = None

    # Special case for pre-generated attacks: load one array per observation
    # key from disk and wrap them in an AdversarialNoise on the eval device.
    # NOTE(review): reads the top-level `cfg.attack_type` (see note above).
    if cfg.attack_type == 'pre_gen':
        datadict = {key: np.load(value['path']) for key, value in cfg.attack.pre_gen.obs.items()}
        noise = AdversarialNoise(datadict=datadict)
        noise.to(device)

    # Run the attack.
    # (Alternative runner calls used during development:
    #  env_runner.run_with_ntarget(policy, 1), env_runner.run_stay_in_place(policy).)
    log = env_runner.run(policy, noise=noise, simulated_physical_attack=simulated_physical_attack, tf_dict=tf_dict)

    if cfg.log_wandb:
        # Work on a copy so the training config can be attached for logging.
        cfg = copy.deepcopy(cfg)
        OmegaConf.set_struct(cfg, False)
        cfg.ckpt_config = workspace_cfg  # also log the original training config
        OmegaConf.set_struct(cfg, True)

        # Configure logging.
        wandb.init(
            dir=str(workspace.output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **cfg.logging
        )
        wandb.config.update(
            {
                "output_dir": workspace.output_dir,
            }
        )
        # Log the results.
        wandb.log(log)

    # Dump the log to json; wandb videos are stored by their file path
    # because the Video object itself is not JSON-serializable.
    json_log = {
        key: value._path if isinstance(value, wandb.sdk.data_types.video.Video) else value
        for key, value in log.items()
    }
    # Append the attack config to the log.
    json_log['cfg'] = OmegaConf.to_container(cfg)

    out_path = os.path.join(workspace.output_dir, 'eval_log.json')
    # Context manager guarantees the file is flushed and closed.
    with open(out_path, 'w') as out_file:
        json.dump(json_log, out_file, indent=2, sort_keys=True)
    print("done")
# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()