-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassifier_trf_hf.py
128 lines (112 loc) · 4.07 KB
/
classifier_trf_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from functools import partial
import sys
from pathlib import Path
import numpy as np
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
Trainer,
EarlyStoppingCallback,
XLMRobertaForSequenceClassification,
XLMRobertaModel,
)
from data import load_corpus, load_corpus_sentence_pairs
from models import BilingualSentenceClassifier
from util import get_training_arguments, compute_metrics, parse_args_hf
def _resolve_output_dir(args, root_dir):
    """Build the run's output directory path from hyperparameters and run mode.

    The directory name encodes the MT system, architecture, learning rate,
    effective batch size, number of epochs and seed; an ``_eval``/``_test``
    suffix is appended when not training.
    """
    model_name = args.arch.replace("/", "-")
    # Which MT system produced the translations; a test corpus name overrides it.
    mt = "google" if args.use_google_data else "deepl"
    # Effective batch size = per-step batch size * gradient accumulation steps.
    eff_bsz = args.gradient_accumulation_steps * args.batch_size
    if args.test:
        mt = args.test
    output_dir = (
        root_dir
        / f"models/{mt}/{model_name}_lr={args.learning_rate}_bsz={eff_bsz}_epochs={args.num_epochs}_seed={args.seed}/"
    )
    if args.eval:
        output_dir = Path(output_dir.parent) / (output_dir.name + "_eval")
    elif args.test:
        output_dir = Path(output_dir.parent) / (output_dir.name + "_test")
    return output_dir


def _load_data(args):
    """Load the train and eval datasets.

    Returns ``(train_data, eval_data, idx_to_docid)``. ``idx_to_docid`` is
    only populated in the monolingual path when documents are split by
    sentence (used downstream for majority-vote classification), else None.
    """
    idx_to_docid = None
    test_or_dev = "test" if args.test else "dev"
    if args.load_sentence_pairs:  # load both source and translations (bilingual)
        train_data = load_corpus_sentence_pairs(args, "train")
        eval_data = load_corpus_sentence_pairs(args, test_or_dev)
    else:  # load only translations (monolingual)
        train_data, _ = load_corpus(args, "train")
        eval_data, idx_to_docid = load_corpus(
            args, test_or_dev, split_docs_by_sentence=args.use_majority_classification
        )
    return train_data, eval_data, idx_to_docid


def _build_model(args):
    """Instantiate the classifier, from a local checkpoint or a pretrained LM."""
    if args.load_model is not None:  # start from a trained model
        print(f"Loading model at {args.load_model}")
        if args.load_sentence_pairs:
            return XLMRobertaForSequenceClassification.from_pretrained(
                args.load_model, local_files_only=True
            )
        return AutoModelForSequenceClassification.from_pretrained(
            args.load_model, local_files_only=True
        )
    model_name = args.arch
    print(f"Loading LM: {model_name}")
    config = AutoConfig.from_pretrained(
        model_name, num_labels=2, classifier_dropout=args.dropout
    )
    if args.load_sentence_pairs:
        # Bilingual setup: wrap the bare encoder (pooling head disabled) in a
        # custom sentence-pair classifier.
        return BilingualSentenceClassifier(
            XLMRobertaModel.from_pretrained(
                model_name,
                config=config,
                local_files_only=False,
                add_pooling_layer=False,
            ),
            config.hidden_size,
            dropout=args.dropout,
        )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name, config=config, local_files_only=False
    )


def main():
    """
    Train or evaluate a model using the Huggingface Trainer API.
    """
    # Get arguments.
    args = parse_args_hf()
    # Set random seed.
    # NOTE(review): only numpy is seeded here; torch/transformers seeding is
    # presumably handled via the Trainer's TrainingArguments seed — confirm.
    np.random.seed(args.seed)
    # Set directories.
    root_dir = Path(args.root_dir)
    if args.load_model is not None:  # initialize a trained model
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would silently skip this validation.
        if not Path(args.load_model).is_dir():
            raise NotADirectoryError(
                f"{args.load_model} is not a checkpoint directory, which it should be."
            )
    args.output_dir = _resolve_output_dir(args, root_dir)
    # Load the data.
    train_data, eval_data, idx_to_docid = _load_data(args)
    # Load the model.
    model = _build_model(args)
    # Setup Huggingface training arguments.
    training_args = get_training_arguments(args)
    # For logging purposes.
    print("Generated by command:\npython", " ".join(sys.argv))
    print("Logging training settings\n", training_args)
    callbacks = [
        EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)
    ]
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=eval_data,
        # idx_to_docid lets the metric function aggregate sentence-level
        # predictions back to documents (majority vote) when applicable.
        compute_metrics=partial(compute_metrics, idx_to_docid=idx_to_docid),
        callbacks=callbacks,
    )
    # Start training/evaluation.
    if args.test or args.eval or args.use_majority_classification:
        mets = trainer.evaluate()
    else:
        mets = trainer.train()
    print("\nInfo:\n", mets, "\n")
# Script entry point: run training/evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()