
Commit 0e32385

finetuning
1 parent b9c54cc commit 0e32385

File tree

1 file changed: +188 -0 lines changed


finetuning/run_finetuning.py
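This script fine-tunes a (customized) masked language model on BLiMP-style minimal pairs: for each example it compares the model's score for the acceptable ("good") candidate token against the unacceptable ("bad") one, presumably at the masked position given INPUT_MASKING and the MaskedLM head, trains with a cross-entropy loss over that pair, reports accuracy per epoch, and saves the resulting weights. A hypothetical invocation, using the commented-out development defaults as illustrative values (the task name is a placeholder, not fixed by this commit):

python finetuning/run_finetuning.py --MODEL_NAME bert --TASK <blimp_task> --MAX_LENGTH 32 --BATCH_SIZE 64 --EPOCHS 5 --GPU 0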

@@ -0,0 +1,188 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--MODEL_NAME")
parser.add_argument("--FIXED", action="store_true")
parser.add_argument("--TASK")
parser.add_argument("--MAX_LENGTH", type=int)
parser.add_argument("--BATCH_SIZE", type=int)
parser.add_argument("--EPOCHS", type=int)
parser.add_argument("--GPU", default=0, type=int)
args = parser.parse_args()

MODEL_NAME = args.MODEL_NAME
FIXED = args.FIXED
TASK = args.TASK
NUM_TRAIN_EPOCHS = args.EPOCHS
MAX_LENGTH = args.MAX_LENGTH
PER_DEVICE_BATCH_SIZE = args.BATCH_SIZE
SELECTED_GPU = args.GPU

# SELECTED_GPU = 0
# MODEL_NAME = 'bert'
# FIXED = False
# TASK = "NA"
# MAX_LENGTH = 32
# NUM_TRAIN_EPOCHS = 5
# PER_DEVICE_BATCH_SIZE = 64

INPUT_MASKING = True
MLM = True
LEARNING_RATE = 3e-5
LR_SCHEDULER_TYPE = "linear"
WARMUP_RATIO = 0.1
SEED = 42
SAVED_MODEL_PATH = f"/home/hmohebbi/Projects/ValueZeroing/directory/models/{MODEL_NAME}/{TASK}/"

# Import Packages
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(sys.modules[__name__].__file__), "..")))
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

from utils.utils import PREPROCESS_FUNC, MODEL_PATH, NUM_LABELS, BLIMP_TASKS

from datasets import (
    load_dataset,
    load_from_disk,
    load_metric,
)
from modeling.customized_modeling_bert import BertForMaskedLM
# from modeling.customized_modeling_roberta import RobertaForMaskedLM
# from modeling.customized_modeling_electra import ElectraForMaskedLM
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AdamW,
    get_scheduler,
    default_data_collator,
    set_seed,
)
set_seed(SEED)

if not os.path.exists(SAVED_MODEL_PATH):
    os.makedirs(SAVED_MODEL_PATH)

# GPU
if torch.cuda.is_available():
    device = torch.device(f"cuda:{SELECTED_GPU}")
    print('We will use the GPU:', torch.cuda.get_device_name(SELECTED_GPU))
else:
    device = torch.device("cpu")
    print('No GPU available, using the CPU instead.')
    # exit()

# Load Dataset
if TASK in BLIMP_TASKS:
    data_path = f"/home/hmohebbi/Projects/ValueZeroing/data/processed_blimp/{MODEL_NAME}/{TASK}"
    data = load_from_disk(data_path)
    train_data = data['train']
    eval_data = data['test']
else:
    print("Not implemented yet!")
    exit()
train_data = train_data.shuffle(SEED)
num_labels = NUM_LABELS[TASK]

# Download Tokenizer & Model
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME], num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])

if MODEL_NAME == "bert":
    model = BertForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
# elif MODEL_NAME == "roberta":
#     model = RobertaForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
# elif MODEL_NAME == "electra":
#     model = ElectraForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
else:
    print("model doesn't exist")
    exit()

model.to(device)

# Preprocessing
train_dataset = PREPROCESS_FUNC[TASK](train_data, tokenizer, MAX_LENGTH, input_masking=INPUT_MASKING, mlm=MLM)
eval_dataset = PREPROCESS_FUNC[TASK](eval_data, tokenizer, MAX_LENGTH, input_masking=INPUT_MASKING, mlm=MLM)

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=PER_DEVICE_BATCH_SIZE)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=PER_DEVICE_BATCH_SIZE)

num_update_steps_per_epoch = len(train_dataloader)
max_train_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch

# Optimizer & LR schedule
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = get_scheduler(
    name=LR_SCHEDULER_TYPE,
    optimizer=optimizer,
    num_warmup_steps=int(WARMUP_RATIO * max_train_steps),
    num_training_steps=max_train_steps,
)

# Metric & Loss
metric = load_metric("accuracy")
loss_fct = CrossEntropyLoss()

tag = "forseqclassification_"
tag += "pretrained" if FIXED else "finetuned"
if MLM:
    tag += "_MLM"

# Train
progress_bar = tqdm(range(max_train_steps))
completed_steps = 0
for epoch in range(NUM_TRAIN_EPOCHS):
    # Training
    model.train()
    for batch in train_dataloader:
        # Vocabulary ids of the acceptable ("good") and unacceptable ("bad") candidate tokens
        good_token_id = batch.pop('good_token_id').to(device)
        bad_token_id = batch.pop('bad_token_id').to(device)
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits

        # Contrastive loss: the good token's logit should exceed the bad token's
        good_logits = logits[torch.arange(logits.size(0)), good_token_id]
        bad_logits = logits[torch.arange(logits.size(0)), bad_token_id]
        logits_of_interest = torch.stack([good_logits, bad_logits], dim=1)
        labels = torch.zeros(logits_of_interest.shape[0], dtype=torch.int64, device=device)  # class 0 = good token
        loss = loss_fct(logits_of_interest, labels)

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        completed_steps += 1

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        if MLM:
            good_token_id = batch.pop('good_token_id').to(device)
            bad_token_id = batch.pop('bad_token_id').to(device)
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        logits = outputs.logits

        if MLM:
            good_logits = logits[torch.arange(logits.size(0)), good_token_id]
            bad_logits = logits[torch.arange(logits.size(0)), bad_token_id]
            logits_of_interest = torch.stack([good_logits, bad_logits], dim=1)
            labels = torch.zeros(logits_of_interest.shape[0], dtype=torch.int64, device=device)
            predictions = torch.argmax(logits_of_interest, dim=-1)
            metric.add_batch(predictions=predictions, references=labels)
        else:
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch['labels'])

    eval_metric = metric.compute()
    print(f"epoch {epoch}: {eval_metric}")


# Save
torch.save(model.state_dict(), f'{SAVED_MODEL_PATH}full_{tag}.pt')
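For context, here is a minimal sketch of the data format the training loop above relies on. The actual preprocessing lives in PREPROCESS_FUNC (utils/utils.py) and is not part of this commit, and the customized BertForMaskedLM appears to return per-example logits for the masked position; under those assumptions, a BLiMP minimal pair could be turned into a masked input plus the vocabulary ids of the good and bad candidate tokens roughly as follows. The helper name build_example and the example sentences are hypothetical.

# A minimal sketch, under the assumptions stated above, of how one BLiMP minimal
# pair might be converted into the fields the training loop pops from each batch
# ('good_token_id', 'bad_token_id') alongside a masked input. This is NOT the
# actual PREPROCESS_FUNC from utils/utils.py; names and sentences are illustrative.
from transformers import AutoTokenizer

def build_example(good_sentence, bad_sentence, tokenizer, max_length=32):
    good_ids = tokenizer(good_sentence, add_special_tokens=False)["input_ids"]
    bad_ids = tokenizer(bad_sentence, add_special_tokens=False)["input_ids"]
    # The sketch assumes the pair differs in exactly one (single-piece) token.
    assert len(good_ids) == len(bad_ids), "pair must differ in a single token"
    diff_pos = next(i for i, (g, b) in enumerate(zip(good_ids, bad_ids)) if g != b)
    good_token_id, bad_token_id = good_ids[diff_pos], bad_ids[diff_pos]
    # Mask the position where the two sentences disagree.
    masked_ids = list(good_ids)
    masked_ids[diff_pos] = tokenizer.mask_token_id
    encoded = tokenizer.prepare_for_model(
        masked_ids, max_length=max_length, padding="max_length", truncation=True
    )
    # Vocabulary ids of the acceptable and unacceptable candidates for the mask.
    encoded["good_token_id"] = good_token_id
    encoded["bad_token_id"] = bad_token_id
    return encoded

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    ex = build_example("The boys are here.", "The boys is here.", tok)
    print(tok.decode(ex["input_ids"]))
    print(ex["good_token_id"], ex["bad_token_id"])

With batches built this way, indexing the [batch_size, vocab_size] logits with good_token_id and bad_token_id and taking cross-entropy against an all-zeros label vector amounts to a pairwise preference loss that pushes the good token's logit above the bad one's, which is what the training loop in this file computes.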
