forked from traveller59/second.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathscript.py
52 lines (46 loc) · 1.77 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from second.pytorch.train import train, evaluate
from google.protobuf import text_format
from second.protos import pipeline_pb2
from pathlib import Path
from second.utils import config_tool
import warnings
# Silence every warning for the remainder of the process. simplefilter("ignore")
# installs the same catch-all filter as filterwarnings("ignore") does (any
# category, any module, any message) and puts it ahead of existing filters.
# NOTE(review): this also hides deprecation warnings raised by imported
# libraries; consider narrowing it with category=... if that matters.
warnings.simplefilter("ignore")
def train_multi_rpn_layer_num(
    config_path="./configs/nuscenes/all.fhd.config",
    model_root=None,
    layer_nums=(2, 4, 7, 9),
):
    """Train one model per RPN depth, each into its own model directory.

    The pipeline config is parsed once. For every value ``n`` in
    *layer_nums* the RPN's ``layer_nums`` field is overwritten with ``[n]``
    and ``train`` is called with ``resume=True`` on the directory
    ``<model_root>/all_fhd_<n>``.

    Args:
        config_path: Path to a text-format ``TrainEvalPipelineConfig`` file.
        model_root: Parent directory for the per-depth model directories
            (``str`` or ``Path``). Defaults to ``~/second_test``.
        layer_nums: RPN layer counts to sweep, in order.
    """
    # Resolved at call time rather than in the signature so the default
    # follows the current user's home directory.
    model_root = (
        Path.home() / "second_test" if model_root is None else Path(model_root)
    )

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        text_format.Merge(f.read(), config)
    model_cfg = config.model.second

    for num_layers in layer_nums:
        model_dir = str(model_root / f"all_fhd_{num_layers}")
        # The same config object is reused across runs; slice assignment
        # replaces the whole repeated field, so the depth from a previous
        # iteration is not carried over.
        model_cfg.rpn.layer_nums[:] = [num_layers]
        train(config, model_dir, resume=True)
def eval_multi_threshold(
    config_path="./configs/nuscenes/all.fhd.config",
    ckpt_path="/home/ags/second_test/all_fhd_2/",
    thresholds=(0.3,),
    result_root=None,
):
    """Evaluate one checkpoint at several NMS score thresholds.

    For each threshold, the model config's ``nms_score_threshold`` is set
    and ``evaluate`` is run with ``batch_size=1`` and ``measure_time=True``.
    Results go to ``<result_root>/second_test_eval_<thresh:.2f>``.

    Args:
        config_path: Path to a text-format ``TrainEvalPipelineConfig`` file.
        ckpt_path: Value forwarded to ``evaluate`` as ``ckpt_path``
            (``str`` or ``Path``). The default is a machine-specific path
            left over from the original script; override it.
        thresholds: NMS score thresholds to sweep, in order.
        result_root: Parent directory for the per-threshold result
            directories (``str`` or ``Path``). Defaults to the home directory.
    """
    result_root = Path.home() if result_root is None else Path(result_root)

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        text_format.Merge(f.read(), config)
    model_cfg = config.model.second

    for thresh in thresholds:
        # The same config object is reused; each iteration overwrites the
        # threshold before evaluating.
        model_cfg.nms_score_threshold = thresh
        result_path = result_root / f"second_test_eval_{thresh:.2f}"
        evaluate(
            config,
            result_path=result_path,
            ckpt_path=str(ckpt_path),
            batch_size=1,
            measure_time=True,
        )
if __name__ == "__main__":
    import argparse

    # Pick the experiment on the command line instead of commenting calls in
    # and out. With no argument, training runs, as it did before.
    _EXPERIMENTS = {
        "train": train_multi_rpn_layer_num,
        "eval": eval_multi_threshold,
    }
    _parser = argparse.ArgumentParser(
        description="Run a SECOND training or evaluation sweep."
    )
    _parser.add_argument(
        "experiment",
        nargs="?",
        default="train",
        choices=sorted(_EXPERIMENTS),
        help="which sweep to run (default: train)",
    )
    _EXPERIMENTS[_parser.parse_args().experiment]()