train.py
import os
import torch
import configparser
import argparse
from tabulate import tabulate
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments
# Load training hyperparameters and the model name from config.ini.
config = configparser.ConfigParser()
config.read('config.ini')

output_dir = config.get('training', 'output_dir')
num_train_epochs = config.getint('training', 'num_train_epochs')
per_device_train_batch_size = config.getint('training', 'per_device_train_batch_size')
per_device_eval_batch_size = config.getint('training', 'per_device_eval_batch_size')
warmup_steps = config.getint('training', 'warmup_steps')
weight_decay = config.getfloat('training', 'weight_decay')
logging_dir = config.get('training', 'logging_dir')
logging_steps = config.getint('training', 'logging_steps')
eval_strategy = config.get('training', 'eval_strategy')
save_strategy = config.get('training', 'save_strategy')
learning_rate = config.getfloat('training', 'learning_rate')
model_name = config.get('model', 'model_name')
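
# The script expects a config.ini next to train.py with the sections and keys
# read above. A minimal sketch — the values below are illustrative assumptions,
# not the project's actual settings:
#
#   [training]
#   output_dir = ./results
#   num_train_epochs = 3
#   per_device_train_batch_size = 16
#   per_device_eval_batch_size = 16
#   warmup_steps = 500
#   weight_decay = 0.01
#   logging_dir = ./logs
#   logging_steps = 10
#   eval_strategy = epoch
#   save_strategy = epoch
#   learning_rate = 5e-5
#
#   [model]
#   model_name = distilbert-base-uncased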
def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision/recall/F1 from a Trainer EvalPrediction."""
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'Accuracy': acc,
        'F1': f1,
        'Precision': precision,
        'Recall': recall
    }
class HealthDataLoader(torch.utils.data.Dataset):
    """Wraps pre-tokenized encodings and labels as a torch Dataset for the Trainer."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)
def load_data(preprocess_path):
    """
    Loads preprocessed tokenized datasets from the given folder and
    returns them as HealthDataLoader instances.

    Args:
        preprocess_path (str): Path to the folder containing the tokenized
            data files (train_tokens.pt, validation_tokens.pt, test_tokens.pt).

    Returns:
        tuple: Three HealthDataLoader instances for the train, validation,
            and test datasets.
    """
    train_data = torch.load(os.path.join(preprocess_path, 'train_tokens.pt'), weights_only=True)
    validation_data = torch.load(os.path.join(preprocess_path, 'validation_tokens.pt'), weights_only=True)
    test_data = torch.load(os.path.join(preprocess_path, 'test_tokens.pt'), weights_only=True)

    train_dataset = HealthDataLoader({'input_ids': train_data['input_ids'], 'attention_mask': train_data['attention_mask']}, train_data['labels'])
    validation_dataset = HealthDataLoader({'input_ids': validation_data['input_ids'], 'attention_mask': validation_data['attention_mask']}, validation_data['labels'])
    test_dataset = HealthDataLoader({'input_ids': test_data['input_ids'], 'attention_mask': test_data['attention_mask']}, test_data['labels'])
    return train_dataset, validation_dataset, test_dataset
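
# Judging from the keys accessed above, each *_tokens.pt file is assumed to be
# a dict of tensors shaped roughly like (shapes are an assumption):
#   {'input_ids': LongTensor[N, seq_len],
#    'attention_mask': LongTensor[N, seq_len],
#    'labels': LongTensor[N]}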
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train DistilBert Model')
    parser.add_argument('--preprocess-path', type=str, default='./preprocess', help='Path to retrieve preprocessed tokens (default: ./preprocess)')
    parser.add_argument('--store-model', type=str, default='./models', help='Path to save trained weights & tokenizer (default: ./models)')
    args = parser.parse_args()
    model_path = args.store_model

    train_dataset, validation_dataset, test_dataset = load_data(args.preprocess_path)

    # Load the pretrained tokenizer and model; the classification head is sized
    # for 4 labels, so mismatched head weights from the checkpoint are ignored.
    tokenizer = DistilBertTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=4, ignore_mismatched_sizes=True)
    model.resize_token_embeddings(len(tokenizer))

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        logging_dir=logging_dir,
        logging_steps=logging_steps,
        eval_strategy=eval_strategy,
        save_strategy=save_strategy,
        learning_rate=learning_rate
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()

    # Evaluate on the held-out test split and print the metrics as a grid table.
    test_result = trainer.evaluate(test_dataset)
    table_data = [[key, round(value, 3)] for key, value in test_result.items()]
    print(tabulate(table_data, headers=["Test Eval Metric", "Value"], tablefmt="grid"))

    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
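
# Example invocation (the paths shown are just the argparse defaults):
#   python train.py --preprocess-path ./preprocess --store-model ./models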