batch_predict.py
import os
import json
import random
import torch
from tqdm import tqdm
from transformers import (
    HfArgumentParser,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
)
from utils import ModelArguments, DataPredictionArguments, read_jsonl
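
# Example invocation (a sketch only; the flag names follow the fields of the
# ModelArguments / DataPredictionArguments dataclasses defined in utils.py,
# and the values shown are placeholders):
#
#   python batch_predict.py \
#       --model_name_or_path <model-path-or-hub-id> \
#       --input_file <people.jsonl> \
#       --batch_size 16 \
#       --max_new_tokens 16 \
#       --num_return_sequences 1
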
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataPredictionArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()

    # Load pretrained tokenizer and config
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, **tokenizer_kwargs
    )
    # Use the EOS token for padding and pad on the left so that generation
    # starts directly after each prompt.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model.to(device)

    # Load prediction data
    data = read_jsonl(data_args.input_file)
    if data_args.max_predict_samples is not None:
        data = random.sample(data, data_args.max_predict_samples)

    def encode_data_prompt(data_point):
        # Build (prompt, ground-truth answer) pairs, one per biographical attribute.
        result = []
        full_name = data_point["full_name"]
        birthday = data_point["birthday"]
        city = data_point["city"]
        major = data_point["major"]
        university = data_point["university"]
        company = data_point["company"]
        prompt = f"{full_name} was born on"
        result.append((prompt, birthday))
        prompt = f"{full_name} lives in the city of"
        result.append((prompt, city))
        prompt = f"{full_name} majors in the field of"
        result.append((prompt, major))
        prompt = f"{full_name} graduates from the"
        result.append((prompt, university))
        prompt = f"{full_name} works for the company of"
        result.append((prompt, company))
        return result
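
    # The input file read by read_jsonl is expected to hold one JSON object per
    # line with the fields accessed above; the record below is illustrative only:
    #   {"full_name": "Jane Doe", "birthday": "January 1, 1990", "city": "Boston",
    #    "major": "Physics", "university": "MIT", "company": "Acme Corp"}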
    encoded_data = []
    for data_point in data:
        for prompt, answer in encode_data_prompt(data_point):
            encoded_data.append((prompt, answer, data_point))

    # Predict
    batch_size = data_args.batch_size
    output_data = []
    for i in tqdm(range(0, len(encoded_data), batch_size)):
        batch = encoded_data[i : i + batch_size]
        prompts = [prompt for prompt, _, _ in batch]
        answers = [answer for _, answer, _ in batch]
        data_points = [data_point for _, _, data_point in batch]
        inputs = tokenizer.batch_encode_plus(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask so left-padded positions are ignored.
            attention_mask=inputs["attention_mask"],
            max_new_tokens=data_args.max_new_tokens,
            num_return_sequences=data_args.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            # do_sample=True,
            # top_k=5,
        )
        for j in range(len(outputs)):
            # Decode, strip the prompt, and keep the text up to the first period.
            # Note: outputs[j] lines up with prompts[j] only when
            # num_return_sequences == 1.
            output = tokenizer.decode(outputs[j], skip_special_tokens=True)
            truncated_output = output[len(prompts[j]) :]
            prediction = truncated_output.split(".")[0].strip()
            output_data.append(
                {
                    "correct": 1 if prediction == answers[j] else 0,
                    "ground_truth": answers[j],
                    "prediction": prediction,
                    "output": output,
                    **data_points[j],
                }
            )

    # Calculate accuracy
    correct_count = sum(data_point["correct"] for data_point in output_data)
    accuracy = correct_count / len(output_data)
    print(f"Accuracy: {accuracy}")

    # Save prediction data
    output_file = os.path.join(model_args.model_name_or_path, "batch_prediction.jsonl")
    with open(output_file, "w", encoding="utf-8") as file:
        for data_point in output_data:
            file.write(json.dumps(data_point, ensure_ascii=False) + "\n")
    print(f"Prediction data saved to {output_file}")
if __name__ == "__main__":
main()