"""Model, loss, and optimizer helpers for the toy, GPT2, and LLaMA examples."""
import functools
from enum import Enum, auto

import torch
import torch.nn as nn
import torch.optim as optim
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

from atorch.common.util_func import data_to_device


class ModelType(Enum):
    TOY = auto()
    GPT2 = auto()
    LLAMA = auto()


def get_model_type(name):
    return getattr(ModelType, name.upper(), None)


class ToyModel(nn.Module):
    def __init__(self, in_features=16, out_features=4, num_linears=8):
        super().__init__()
        self.first_linear = nn.Linear(in_features, out_features)
        self.linears = torch.nn.ModuleList(
            [nn.Linear(out_features, out_features) for _ in range(num_linears - 1)]
        )

    def forward(self, inputs):
        res = self.first_linear(inputs["input"])
        for op in self.linears:
            res = op(res)
        return res


class MyGPT2Model(GPT2Model):
    """GPT2Model with a separate language-modeling head; forward returns raw logits."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

    # `labels` is accepted for interface compatibility but unused here; the loss
    # is computed externally (see gpt2_loss_func below).
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        transformer_outputs = super().forward(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        return lm_logits

    def vocab_size(self):
        return self.config.vocab_size


def optim_func(model_parameters, **kwargs):
    return optim.Adam(model_parameters, **kwargs)


def toy_loss_func(inputs, output):
    loss = nn.MSELoss()
    return loss(inputs["label"], output)


def gpt2_loss_func(inputs, output, vocab_size):
    # Next-token prediction: compare logits at position i with labels at position i + 1.
    shift_logits = output[..., :-1, :].contiguous()
    shift_labels = inputs["labels"][..., 1:].contiguous()
    criterion = torch.nn.CrossEntropyLoss()
    return criterion(shift_logits.view(-1, vocab_size), shift_labels.view(-1))


def llama_loss_func(inputs, output):
    # LlamaForCausalLM computes its own loss when labels are provided in the inputs.
    return output.loss


def prepare_input(data, device):
    return data_to_device(data, device)


def get_model(model_type, config):
    # config: dict with in_features/out_features/num_linears for the toy model,
    # or hidden_size/head_num/layer_num/seq_length for the LLMs.
    if model_type == ModelType.TOY:
        model = ToyModel(
            in_features=config["in_features"],
            out_features=config["out_features"],
            num_linears=config["num_linears"],
        )
        return model

    # llms
    hidden_size = config["hidden_size"]
    head_num = config["head_num"]
    layer_num = config["layer_num"]
    seq_length = config["seq_length"]
    if model_type == ModelType.GPT2:
        model_config = GPT2Config()
        c_s = f"n_embd={hidden_size},n_head={head_num},n_layer={layer_num},n_positions={seq_length}"
        model_config.update_from_string(c_s)
        model = MyGPT2Model(model_config)
    elif model_type == ModelType.LLAMA:
        model_config = LlamaConfig()
        c_s = f"hidden_size={hidden_size},num_attention_heads={head_num},num_hidden_layers={layer_num},"
        c_s += f"num_key_value_heads={head_num},max_position_embeddings={seq_length}"
        model_config.update_from_string(c_s)
        model = LlamaForCausalLM(model_config)
    return model
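

# For reference, the config dicts consumed by get_model use the keys read above;
# the values in these examples are illustrative assumptions, not taken from the file:
#   ModelType.TOY:               {"in_features": 16, "out_features": 4, "num_linears": 8}
#   ModelType.GPT2 / ModelType.LLAMA:
#       {"hidden_size": 256, "head_num": 4, "layer_num": 2, "seq_length": 128}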


def get_module_type(model_type):
    if model_type == ModelType.TOY:
        return nn.Linear
    if model_type == ModelType.GPT2:
        return GPT2Block
    if model_type == ModelType.LLAMA:
        return LlamaDecoderLayer
    return None


def get_model_input_format(model_type):
    # get model input format: "unpack_sequence", "unpack_dict", or None.
    if model_type == ModelType.TOY:
        return None
    if model_type == ModelType.GPT2:
        return "unpack_dict"
    if model_type == ModelType.LLAMA:
        return "unpack_dict"
    return None


def get_vocab_size(model_type):
    size = 0
    if model_type == ModelType.GPT2:
        config = GPT2Config()
        size = config.vocab_size
    if model_type == ModelType.LLAMA:
        config = LlamaConfig()
        size = config.vocab_size
    return size


def get_loss_func(model_type):
    if model_type == ModelType.TOY:
        return toy_loss_func
    if model_type == ModelType.GPT2:
        vocab_size = get_vocab_size(model_type)
        return functools.partial(gpt2_loss_func, vocab_size=vocab_size)
    if model_type == ModelType.LLAMA:
        return llama_loss_func
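

if __name__ == "__main__":
    # Minimal usage sketch, assuming the toy configuration below; the shapes,
    # learning rate, and random batch are illustrative and not from the original
    # file. It builds a model, runs one forward/backward pass, and steps Adam.
    model_type = get_model_type("toy")
    model = get_model(model_type, {"in_features": 16, "out_features": 4, "num_linears": 8})
    loss_func = get_loss_func(model_type)
    optimizer = optim_func(model.parameters(), lr=1e-3)

    batch = {"input": torch.randn(4, 16), "label": torch.randn(4, 4)}
    output = model(batch)
    loss = loss_func(batch, output)
    loss.backward()
    optimizer.step()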