add bloom models #73

Open
wants to merge 41 commits into main

Commits (41)
f21223a
bug fix
ydli-ai Nov 24, 2022
cf74ac6
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Nov 25, 2022
efc6539
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Dec 1, 2022
9fe4f32
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Dec 14, 2022
6baa170
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Jan 4, 2023
ec90e5e
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Mar 7, 2023
a1047b7
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Mar 16, 2023
3816b6b
Merge branch 'main' of https://github.com/Tencent/TencentPretrain
ydli-ai Mar 30, 2023
a0e7c8f
add lora for pretrain.
fengyh3 Apr 2, 2023
ffc42e6
add lora for decoder and fix some bug.
fengyh3 Apr 3, 2023
0f3c195
fix lora deepspeed save model.
fengyh3 Apr 4, 2023
7deba29
Merge pull request #1 from fengyh3/dev
fengyh3 Apr 4, 2023
c11f465
Dev (#5)
ydli-ai Apr 4, 2023
de2e59a
Llama v2 (#6)
ydli-ai Apr 4, 2023
c5979a8
add zero3 load lora
ydli-ai Apr 4, 2023
05a8cda
add llama config
ydli-ai Apr 4, 2023
d732942
add llama config
ydli-ai Apr 5, 2023
143c3aa
add llama config
ydli-ai Apr 5, 2023
7b92d3d
add llama config
ydli-ai Apr 5, 2023
834c1c3
add llama config
ydli-ai Apr 5, 2023
c8ca114
add llama config
ydli-ai Apr 5, 2023
6747460
add llama config
ydli-ai Apr 5, 2023
35edf4c
add llama config
ydli-ai Apr 5, 2023
81fb0e4
add llama config
ydli-ai Apr 5, 2023
c36d3ee
add llama config
ydli-ai Apr 5, 2023
dfaf694
add llama config
ydli-ai Apr 5, 2023
caa93bc
add llama config
ydli-ai Apr 5, 2023
82a2794
add llama config
ydli-ai Apr 5, 2023
eed09ee
add llama config
ydli-ai Apr 5, 2023
3142668
add llama config
ydli-ai Apr 5, 2023
1255800
Merge branch 'main' into lora
ydli-ai Apr 5, 2023
4a774f0
add llama config
ydli-ai Apr 5, 2023
e78ce16
add llama config
ydli-ai Apr 5, 2023
85600fb
add llama config
ydli-ai Apr 7, 2023
af7d1c8
add llama config
ydli-ai Apr 7, 2023
038e58c
add llama config
ydli-ai Apr 7, 2023
1f44866
add llama config
ydli-ai Apr 7, 2023
02924b7
add llama config
ydli-ai Apr 7, 2023
974ddf0
Merge pull request #2 from ydli-ai/lora
fengyh3 Apr 7, 2023
b42c0db
Merge branch 'Tencent:main' into main
fengyh3 Apr 10, 2023
cc0fa81
add bloom
fengyh3 Apr 16, 2023
22 changes: 22 additions & 0 deletions models/bloom/bloom_1b1_config.json
@@ -0,0 +1,22 @@
{
    "emb_size": 1536,
    "feedforward_size": 6144,
    "hidden_size": 1536,
    "hidden_act": "gelu",
    "heads_num": 16,
    "layers_num": 24,
    "dropout": 0.0,
    "data_processor": "lm",
    "embedding": ["word"],
    "remove_transformer_bias": false,
    "has_lmtarget_bias": false,
    "remove_embedding_layernorm": false,
    "encoder": "transformer",
    "mask": "causal",
    "layernorm_positioning": "pre",
    "target": ["lm"],
    "tie_weights": true,
    "alibi_position_embedding": true,
    "layer_number_scale": true,
    "vocab_size": 250880
}
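The config above pins the BLOOM-1b1 shapes: 24 layers, hidden size 1536, 16 heads, and a 4× feed-forward width. As a quick orientation — a sketch, not part of the PR — these are the same values the conversion script below expects as `--layers_num`, `--hidden_size` and `--head_num`:

```python
# Hedged sketch (not part of this PR): read the config above and derive the
# values the conversion script expects on its command line.
import json

with open("models/bloom/bloom_1b1_config.json") as f:
    cfg = json.load(f)

print(cfg["layers_num"], cfg["hidden_size"], cfg["heads_num"])   # 24 1536 16
print(cfg["feedforward_size"] == 4 * cfg["hidden_size"])         # True: 4x FFN
print(cfg["hidden_size"] // cfg["heads_num"])                    # 96, per-head attention size
```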
90 changes: 90 additions & 0 deletions scripts/convert_bloom_from_huggingface_to_tencentpretrain.py
@@ -0,0 +1,90 @@
import argparse
import collections
import torch
import os


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/llama-7b/",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/llama-7b.bin",
help=".")
parser.add_argument("--layers_num", type=int, required=True)
parser.add_argument("--hidden_size", type=int, required=True)
parser.add_argument("--head_num", type=int, required=True)

args = parser.parse_args()

files = os.listdir(args.input_model_path)
model_files = [f for f in files if f[-4:] == ".bin"]
output_model = collections.OrderedDict()
output_model_mapping = collections.OrderedDict()

output_model_mapping['embedding.word.embedding.weight'] = 'word_embeddings.weight'
output_model_mapping['embedding.layer_norm.gamma'] = 'word_embeddings_layernorm.weight'
output_model_mapping['embedding.layer_norm.beta'] = 'word_embeddings_layernorm.bias'

for i in range(args.layers_num):
    # attention ln
    output_model_mapping["encoder.transformer." + str(i) + ".layer_norm_1.gamma"] = \
        "h." + str(i) + ".input_layernorm.weight"
    output_model_mapping["encoder.transformer." + str(i) + ".layer_norm_1.beta"] = \
        "h." + str(i) + ".input_layernorm.bias"

    # attention weight
    output_model_mapping["encoder.transformer." + str(i) + ".self_attn.linear_layers.weight"] = \
        'h.' + str(i) + '.self_attention.query_key_value.weight'

    # attention bias
    output_model_mapping["encoder.transformer." + str(i) + ".self_attn.linear_layers.bias"] = \
        'h.' + str(i) + '.self_attention.query_key_value.bias'

    # attention output
    output_model_mapping['encoder.transformer.' + str(i) + '.self_attn.final_linear.weight'] = \
        'h.' + str(i) + '.self_attention.dense.weight'
    output_model_mapping['encoder.transformer.' + str(i) + '.self_attn.final_linear.bias'] = \
        'h.' + str(i) + '.self_attention.dense.bias'

    # FFN ln
    output_model_mapping["encoder.transformer." + str(i) + ".layer_norm_2.gamma"] = \
        'h.' + str(i) + '.post_attention_layernorm.weight'
    output_model_mapping["encoder.transformer." + str(i) + ".layer_norm_2.beta"] = \
        'h.' + str(i) + '.post_attention_layernorm.bias'

    # FFN
    output_model_mapping['encoder.transformer.' + str(i) + '.feed_forward.linear_1.weight'] = \
        'h.' + str(i) + '.mlp.dense_h_to_4h.weight'
    output_model_mapping['encoder.transformer.' + str(i) + '.feed_forward.linear_1.bias'] = \
        'h.' + str(i) + '.mlp.dense_h_to_4h.bias'
    output_model_mapping['encoder.transformer.' + str(i) + '.feed_forward.linear_2.weight'] = \
        'h.' + str(i) + '.mlp.dense_4h_to_h.weight'
    output_model_mapping['encoder.transformer.' + str(i) + '.feed_forward.linear_2.bias'] = \
        'h.' + str(i) + '.mlp.dense_4h_to_h.bias'

output_model_mapping['encoder.layer_norm.gamma'] = 'ln_f.weight'
output_model_mapping['encoder.layer_norm.beta'] = 'ln_f.bias'

input_model_mapping = {v: k for k, v in output_model_mapping.items()}
head_per_size = args.hidden_size // args.head_num

for f in model_files:
    checkpoint = torch.load(os.path.join(args.input_model_path, f), map_location='cpu')
    for name, parm in checkpoint.items():
        if 'query_key_value' in name:
            module_name = input_model_mapping[name].split('.')
            if 'weight' in name:
                parm = parm.reshape((args.head_num, head_per_size * 3, args.hidden_size))
                q, k, v = torch.split(parm, head_per_size, dim=-2)
                output_model['.'.join(module_name[:-1]) + '.0.' + module_name[-1]] = q.reshape((args.hidden_size, args.hidden_size))
                output_model['.'.join(module_name[:-1]) + '.1.' + module_name[-1]] = k.reshape((args.hidden_size, args.hidden_size))
                output_model['.'.join(module_name[:-1]) + '.2.' + module_name[-1]] = v.reshape((args.hidden_size, args.hidden_size))
            else:
                parm = parm.reshape((args.head_num, head_per_size * 3))
                q, k, v = torch.split(parm, head_per_size, dim=-1)
                output_model['.'.join(module_name[:-1]) + '.0.' + module_name[-1]] = q.reshape((args.hidden_size))
                output_model['.'.join(module_name[:-1]) + '.1.' + module_name[-1]] = k.reshape((args.hidden_size))
                output_model['.'.join(module_name[:-1]) + '.2.' + module_name[-1]] = v.reshape((args.hidden_size))
        else:
            output_model[input_model_mapping[name]] = parm

torch.save(output_model, args.output_model_path)
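The delicate step in this script is the fused `query_key_value` projection: the reshape to `(head_num, 3 * head_size, hidden_size)` assumes the HuggingFace BLOOM checkpoint interleaves the Q/K/V blocks per head, and the split then regroups them into the three separate linear layers TencentPretrain uses. A shape-only sketch of that step with the BLOOM-1b1 sizes (illustrative, not part of the PR; the interleaved layout is inferred from the reshape above):

```python
# Shape-only sketch of the fused QKV split above, using BLOOM-1b1 sizes.
import torch

hidden_size, head_num = 1536, 16
head_size = hidden_size // head_num                   # 96

# Fused projection weight, (3 * hidden_size, hidden_size), grouped per head.
fused = torch.randn(3 * hidden_size, hidden_size)

per_head = fused.reshape(head_num, 3 * head_size, hidden_size)
q, k, v = torch.split(per_head, head_size, dim=-2)    # each (head_num, head_size, hidden_size)
q, k, v = (t.reshape(hidden_size, hidden_size) for t in (q, k, v))

assert q.shape == k.shape == v.shape == (hidden_size, hidden_size)
```

With the config above, the script would be invoked with `--layers_num 24 --hidden_size 1536 --head_num 16`.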
119 changes: 119 additions & 0 deletions scripts/generate_bloom.py
@@ -0,0 +1,119 @@
"""
This script provides an example of wrapping TencentPretrain for generation.
Given the beginning of a text, the language model generates the rest.
"""
import sys
import os
import argparse
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.targets import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import *
from tencentpretrain.opts import infer_opts, tokenizer_opts, model_opts


class GenerateLm(torch.nn.Module):
    def __init__(self, args):
        super(GenerateLm, self).__init__()
        self.embedding = Embedding(args)
        for embedding_name in args.embedding:
            tmp_emb = str2embedding[embedding_name](args, args.vocab_size)
            self.embedding.update(tmp_emb, embedding_name)
        self.encoder = str2encoder[args.encoder](args)
        self.target = Target()
        self.target.update(LmTarget(args, args.vocab_size), "lm")
        if args.tie_weights:
            self.target.lm.output_layer.weight = self.embedding.word.embedding.weight

    def forward(self, src, seg):
        emb = self.embedding(src, seg)
        output = self.encoder(emb, seg)
        output = self.target.lm.output_layer(output)
        return output


def top_k_top_p_filtering(logits, top_k, top_p):
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float("Inf")

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float("Inf")
    return logits


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)
    parser.add_argument("--tokenizer_path", default=None, type=str,
                        help="Path of the tokenizer.")
    parser.add_argument("--top_k", type=int, default=70)
    parser.add_argument("--top_p", type=float, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)

    args = parser.parse_args()
    args.target = "lm"
    args.batch_size = 1

    # todo: convert to tencentpretrain tokenizer
    args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    args = load_hyperparam(args)

    model = GenerateLm(args)
    model = load_model(model, args.load_model_path)
    checkpoint = torch.load(args.load_model_path, map_location='cpu')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    model.eval()

    with open(args.test_path, mode="r", encoding="utf-8") as f:
        line = f.readline().strip()
        src = args.tokenizer.encode(line)
        seg = [1] * len(src)
        beginning_length = len(src)
        if len(src) > args.seq_length:
            src = src[:args.seq_length]
            seg = seg[:args.seq_length]
        src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        for i in range(args.seq_length - beginning_length):
            with torch.no_grad():
                output = model(src_tensor, seg_tensor)
            next_token_logits = output[0][-1] / args.temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
            seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)

        f.write(line + "\n")
        tokens = [token_id.item() for token_id in src_tensor[0]]
        generated_sentence = args.tokenizer.decode(tokens)

        f.write(generated_sentence)
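The sampling loop relies on `top_k_top_p_filtering` to mask out unlikely tokens before `torch.multinomial`; the remaining options (`--load_model_path`, `--test_path`, `--prediction_path`, `--seq_length`) come from `infer_opts`, which is not shown in this diff, so those names are assumed here. A toy illustration of the filtering step (not part of the PR):

```python
# Toy illustration of the top-k filtering used in the sampling loop above.
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])
top_k = 2

# Everything below the k-th largest logit is pushed to -inf, exactly as in
# top_k_top_p_filtering; softmax then spreads all mass over the surviving tokens.
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float("Inf")
print(logits)                      # tensor([2., 1., -inf, -inf, -inf])
print(F.softmax(logits, dim=-1))   # roughly [0.73, 0.27, 0., 0., 0.]
```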
26 changes: 18 additions & 8 deletions tencentpretrain/encoders/transformer_encoder.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from tencentpretrain.utils.rope import precompute_freqs_cis
from tencentpretrain.utils.alibi import build_alibi_tensor
from tencentpretrain.layers.transformer import TransformerLayer
from tencentpretrain.layers.layer_norm import *
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding
@@ -13,11 +14,13 @@ def __init__(self, args):
super(TransformerEncoder, self).__init__()
self.mask = args.mask
self.layers_num = args.layers_num
self.heads_num = args.heads_num
self.parameter_sharing = args.parameter_sharing
self.factorized_embedding_parameterization = args.factorized_embedding_parameterization
self.layernorm_positioning = args.layernorm_positioning
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.alibi_position_embedding = args.alibi_position_embedding
self.has_residual_attention = args.has_residual_attention
if "deepspeed_checkpoint_activations" in args:
self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
@@ -34,15 +37,16 @@ def __init__(self, args):
self.transformer = TransformerLayer(args)
else:
self.transformer = nn.ModuleList(
[TransformerLayer(args) for _ in range(self.layers_num)]
[TransformerLayer(args, i if args.layer_number_scale or args.alibi_position_embedding else None)
for i in range(self.layers_num)]
)
if self.layernorm_positioning == "pre":
if args.layernorm == "t5":
self.layer_norm = T5LayerNorm(args.hidden_size)
self.layer_norm = T5LayerNorm(args.hidden_size, args.eps)
elif args.layernorm == "rms":
self.layer_norm = RMSNorm(args.hidden_size)
self.layer_norm = RMSNorm(args.hidden_size, args.eps)
else:
self.layer_norm = LayerNorm(args.hidden_size)
self.layer_norm = LayerNorm(args.hidden_size, args.eps)

if self.relative_position_embedding:
self.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
@@ -107,6 +111,12 @@ def forward(self, emb, seg):
else:
freqs_cis = None

if self.alibi_position_embedding:
attention_mask = torch.ones((batch_size, seq_length), device=hidden.device)
alibi = build_alibi_tensor(attention_mask, self.heads_num, hidden.dtype, hidden.device)
else:
alibi = None

prev_attn = None

if self.deepspeed_checkpoint_activations:
@@ -119,11 +129,11 @@ def custom_forward(*inputs):
if self.parameter_sharing:
x_, y_ = self.transformer(x_, mask, position_bias=position_bias_,
has_residual_attention=self.has_residual_attention,
prev_attn=y_, freqs_cis=freqs_cis_)
prev_attn=y_, freqs_cis=freqs_cis_, alibi=alibi)
else:
x_, y_ = self.transformer[index](x_, mask, position_bias=position_bias_,
has_residual_attention=self.has_residual_attention,
prev_attn=y_, freqs_cis=freqs_cis_)
prev_attn=y_, freqs_cis=freqs_cis_, alibi=alibi)
return x_, y_

return custom_forward
@@ -137,11 +147,11 @@ def custom_forward(*inputs):
if self.parameter_sharing:
hidden, prev_attn = self.transformer(hidden, mask, position_bias=position_bias,
has_residual_attention=self.has_residual_attention,
prev_attn=prev_attn, freqs_cis=freqs_cis)
prev_attn=prev_attn, freqs_cis=freqs_cis, alibi=alibi)
else:
hidden, prev_attn = self.transformer[i](hidden, mask, position_bias=position_bias,
has_residual_attention=self.has_residual_attention,
prev_attn=prev_attn, freqs_cis=freqs_cis)
prev_attn=prev_attn, freqs_cis=freqs_cis, alibi=alibi)

if self.layernorm_positioning == "pre":
return self.layer_norm(hidden)
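The forward pass now builds an ALiBi bias from an all-ones attention mask via `build_alibi_tensor` (imported from `tencentpretrain.utils.alibi`, not shown in this diff) and passes it to every layer. For orientation only, a minimal sketch of the standard ALiBi construction that produces the `(batch * heads, 1, seq_length)` shape the attention code consumes; it assumes `heads_num` is a power of two (true for BLOOM-1b1's 16 heads) and is not the PR's actual implementation:

```python
# Minimal ALiBi sketch (illustrative; the PR's build_alibi_tensor is not shown
# in this diff). Assumes heads_num is a power of two, as for BLOOM-1b1 (16 heads).
import torch

def alibi_sketch(attention_mask, heads_num, dtype, device):
    batch_size, seq_length = attention_mask.shape
    # Head-specific slopes 2^(-8/n), 2^(-16/n), ..., as in the ALiBi paper.
    base = 2.0 ** (-8.0 / heads_num)
    slopes = torch.pow(
        torch.full((heads_num,), base, dtype=torch.float32, device=device),
        torch.arange(1, heads_num + 1, dtype=torch.float32, device=device),
    )
    # 0-based positions of non-padding tokens, derived from the attention mask.
    positions = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    bias = slopes[None, :, None] * positions               # (batch, heads, seq)
    return bias.reshape(batch_size * heads_num, 1, seq_length).to(dtype)

mask = torch.ones(2, 8)
print(alibi_sketch(mask, 16, torch.float32, "cpu").shape)  # torch.Size([32, 1, 8])
```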
10 changes: 7 additions & 3 deletions tencentpretrain/layers/layer_norm.py
@@ -7,16 +7,20 @@ class LayerNorm(nn.Module):
Layer Normalization.
https://arxiv.org/abs/1607.06450
"""
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, hidden_size, eps=1e-6, eps_inside=False):
super(LayerNorm, self).__init__()
self.eps = eps
self.eps_inside = eps_inside
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.beta = nn.Parameter(torch.zeros(hidden_size))

def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
hidden_states = self.gamma * (x-mean) / (std + self.eps)
if self.eps_inside:
std = torch.sqrt(x.var(-1, keepdim=True) + self.eps)
else:
std = x.std(-1, keepdim=True) + self.eps
hidden_states = self.gamma * (x-mean) / std

return hidden_states + self.beta

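The new `eps_inside` option moves epsilon inside the square root, i.e. it normalizes by `sqrt(var + eps)` as in the usual LayerNorm definition (and `torch.nn.LayerNorm`), instead of adding epsilon to the standard deviation afterwards. The two differ only marginally for well-scaled activations; presumably the switch exists so converted BLOOM checkpoints reproduce their original normalization. A small comparison (illustrative, not part of the PR):

```python
# Compare the two epsilon placements in the LayerNorm above.
import torch

x = torch.randn(4, 16)
eps = 1e-6

std_outside = x.std(-1, keepdim=True) + eps              # original form (eps_inside=False)
std_inside = torch.sqrt(x.var(-1, keepdim=True) + eps)   # new form (eps_inside=True)

print((std_outside - std_inside).abs().max())            # tiny, but not exactly zero
```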
24 changes: 21 additions & 3 deletions tencentpretrain/layers/multi_headed_attn.py
@@ -11,7 +11,7 @@ class MultiHeadedAttention(nn.Module):
"""

def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True,
lora_params=None):
lora_params=None, layer_number=None):
super(MultiHeadedAttention, self).__init__()
self.heads_num = heads_num

@@ -36,9 +36,13 @@ def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bia
)
self.dropout = nn.Dropout(dropout)
self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias)
# layer-wise attention scaling
if layer_number is not None:
    self.layer_number = max(1, layer_number)
    self.norm_factor = math.sqrt(self.per_head_size) * self.layer_number
else:
    self.layer_number = None

def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None,
freqs_cis=None):
freqs_cis=None, alibi=None):
"""
Args:
key: [batch_size x seq_length x hidden_size]
@@ -76,8 +80,22 @@ def unshape(x):
if position_bias is not None:
scores = scores + position_bias
if self.with_scale:
scores = scores / math.sqrt(float(per_head_size))
if self.layer_number is not None:
scores = scores * (1.0 / self.norm_factor)
else:
scores = scores / math.sqrt(float(per_head_size))
if alibi is not None:
scores = scores.reshape((-1, scores.shape[-2], scores.shape[-1]))
scores += (1.0 / self.layer_number) * alibi
scores = scores.view(-1, heads_num, scores.shape[-2], scores.shape[-1])

scores = scores + mask.type_as(scores)

# scaled softmax
if self.layer_number is not None:
scores = (scores * self.layer_number) + mask
scores = torch.max(scores, torch.tensor(-10000))

prev_attn_out = None
if has_residual_attention:
if prev_attn is not None:
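The `layer_number` plumbing above mirrors Megatron/BLOOM-style query-key layer scaling: the scores are divided by `sqrt(d) * layer_number` (and the ALiBi bias by `layer_number`), then multiplied back by `layer_number` just before the softmax. This keeps the intermediate products smaller, presumably for fp16 stability, without changing the attention distribution. A quick numerical check of that invariance (illustrative only, with the mask omitted):

```python
# Illustrative check (not part of the PR): scaling scores down by layer_number
# and back up before the softmax leaves the attention distribution unchanged.
import torch
import torch.nn.functional as F

layer_number = 5
scores = torch.randn(2, 8, 8) * 10.0     # stand-in for Q @ K^T / sqrt(d)

direct = F.softmax(scores, dim=-1)
round_trip = F.softmax((scores / layer_number) * layer_number, dim=-1)
print(torch.allclose(direct, round_trip, atol=1e-4))  # True, up to floating-point error
```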