
Commit

wip: working on better attention
ex3ndr committed Jul 10, 2024
1 parent 010b915 commit 6b9e1b6
Showing 9 changed files with 321 additions and 122,744 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@
.env
__pycache__
/wandb
.ipynb_checkpoints
.ipynb_checkpoints
train_tokenizer_text.txt
143 changes: 143 additions & 0 deletions attention.ipynb
@@ -0,0 +1,143 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9145ba44-c6d3-4ac1-874e-f54e110c15cf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from supervoice_valle import Attend\n",
"from torch.nn.attention import SDPBackend, sdpa_kernel\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c585a532-bf54-4abf-b306-5d820151acd5",
"metadata": {},
"outputs": [],
"source": [
"a_pt = Attend(engine = \"torch\", heads = 32).to(\"cuda\").eval()\n",
"a_nt = Attend(engine = \"direct\", heads = 32).to(\"cuda\").eval()\n",
"a_xt = Attend(engine = \"xformers\", heads = 32).to(\"cuda\").eval()\n",
"a_ft = Attend(engine = \"flash\", heads = 32).to(\"cuda\").eval()\n",
"attentions = [a_pt, a_nt, a_xt, a_ft]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bfdeb604-64a0-4f94-9127-9700ba09c9ac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Without padding\n",
"torch 0.0\n",
"direct 0.00048828125\n",
"xformers 0.0\n",
"flash 0.0\n",
"With padding\n",
"torch 0.0\n",
"direct 0.00048828125\n",
"xformers 0.000244140625\n",
"flash 0.000244140625\n"
]
}
],
"source": [
"# Source\n",
"query = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
"key = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
"value = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
"lengths = [4, 8, 8, 12]\n",
"\n",
"print(\"Without padding\")\n",
"source = a_pt(query, key, value)\n",
"for a in attentions:\n",
" dest = a(query, key, value)\n",
" print(a.engine, (dest - source).abs().max().item())\n",
"\n",
"print(\"With padding\")\n",
"source = a_pt(query, key, value, lenghts = lengths)\n",
"for a in attentions:\n",
" dest = a(query, key, value, lenghts = lengths)\n",
" print(a.engine, (dest - source).abs().max().item())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "039f7197-074f-4776-a5e7-8c2b5eaabe4b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Without padding\n",
"torch 0.4039938449859619\n",
"direct 0.9630444049835205\n",
"xformers 0.9124343395233154\n",
"flash 0.3562278747558594\n",
"With padding\n",
"torch 1.6574337482452393\n",
"direct 2.504969835281372\n",
"xformers 1.5768730640411377\n",
"flash 1.4189743995666504\n"
]
}
],
"source": [
"# Benchmark\n",
"print(\"Without padding\")\n",
"for a in attentions:\n",
" start = time.time()\n",
" for i in range(10000):\n",
" a(query, key, value)\n",
" print(a.engine, time.time() - start)\n",
"\n",
"print(\"With padding\")\n",
"for a in attentions:\n",
" start = time.time()\n",
" for i in range(10000):\n",
" a(query, key, value, lenghts = lengths)\n",
" print(a.engine, time.time() - start)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05aa22c4-38ca-48e2-8402-f7de96c9be17",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
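For reference, the correctness check and timing loop in this notebook can be reproduced without the repository at all. The sketch below is illustrative only, assuming a CUDA device and the head-first (B, H, L, E) layout that `torch.nn.functional.scaled_dot_product_attention` expects; `naive_attention` is a stand-in reference, not the repo's `direct` engine.

```python
# Minimal, standalone version of the notebook's correctness check and timing
# loop. naive_attention is an illustrative explicit softmax(QK^T / sqrt(E)) V,
# not the repository's Attend module.
import math
import time

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


q = torch.rand(1, 8, 32, 16, dtype=torch.float16, device="cuda")
k = torch.rand_like(q)
v = torch.rand_like(q)

# Correctness: compare against torch's fused SDPA, as the notebook does per engine.
reference = F.scaled_dot_product_attention(q, k, v)
print("max abs diff:", (naive_attention(q, k, v) - reference).abs().max().item())

# Timing in the same spirit as the benchmark cell, but with an explicit
# synchronize so the number covers kernel completion, not just launches.
torch.cuda.synchronize()
start = time.time()
for _ in range(10_000):
    F.scaled_dot_product_attention(q, k, v)
torch.cuda.synchronize()
print("sdpa:", time.time() - start)
```

The small non-zero differences reported for the `direct` engine above (about 5e-4) are consistent with ordinary float16 rounding, since the fused kernels accumulate the softmax in higher precision. Note also that the benchmark cell never calls `torch.cuda.synchronize()`, so its timings mostly reflect kernel launch and queueing rather than guaranteed completion; the synchronized loop in the sketch gives a fairer wall-clock number.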
Binary file added mkbhd.m4a
Binary file not shown.
117 changes: 90 additions & 27 deletions supervoice_valle/attention.py
@@ -1,65 +1,128 @@
import torch
from torch import nn
import math
from einops import rearrange
import xformers.ops as xops
from flash_attn import flash_attn_func, flash_attn_varlen_func

class Attend(nn.Module):
def __init__(self, *, heads, dropout = 0., engine = "direct"):
class Attend(torch.nn.Module):
def __init__(self, *, heads, engine = "direct"):
super().__init__()
self.heads = heads
self.dropout = nn.Dropout(dropout)
self.dropout_p = dropout
self.engine = engine

def forward(self, q, k, v, padding_mask = None):
def forward(self, q, k, v, lenghts = None):

# Check argument shapes
assert q.dim() == 4
assert k.dim() == 4
assert v.dim() == 4
assert q.size(0) == k.size(0) == v.size(0), "Batch size mismatch"
assert q.size(1) == k.size(1) == v.size(1) == self.heads, "Heads length mismatch"
assert q.size(2) == k.size(2) == v.size(2), "Sequence length mismatch"
assert q.size(1) == k.size(1) == v.size(1), "Sequence length mismatch"
assert q.size(2) == k.size(2) == v.size(2) == self.heads, "Heads length mismatch"
assert q.size(3) == k.size(3) == v.size(3), "Embeddings dimensions mismatch"
if padding_mask is not None:
assert padding_mask.dim() == 2
assert padding_mask.size(0) == q.size(0)
assert padding_mask.size(2) == q.size(2)
if lenghts is not None:
assert sum(lenghts) == q.size(1)

# Check padding mask
if self.engine == "direct":
return self.direct_attention(q, k, v, padding_mask)
return self.direct_attention(q, k, v, lenghts)
elif self.engine == "torch":
return self.pytorch_attention(q, k, v, padding_mask)
return self.pytorch_attention(q, k, v, lenghts)
elif self.engine == "xformers":
return self.xformers_attention(q, k, v, lenghts)
elif self.engine == "flash":
return self.flash_attention(q, k, v, lenghts)
else:
raise ValueError("Invalid engine")

def pytorch_attention(self, q, k, v, padding_mask):
(B, H, L, E) = q.size()
def flash_attention(self, q, k, v, lengths):
(B, L, H, E) = q.size()

# With lengths
if lengths is not None:

# Max lengths
max_len = torch.tensor(max(lengths), dtype = q.dtype, device = q.device)

# Seq lens
seqlens = [0]
last = 0
for l in lengths:
last += l
seqlens.append(last)
seqlens = torch.tensor(seqlens, dtype = torch.int32, device = q.device)

return flash_attn_varlen_func(q.squeeze(0), k.squeeze(0), v.squeeze(0), seqlens, seqlens, max_len, max_len).unsqueeze(0)

# Non length
return flash_attn_func(q, k, v)


def xformers_attention(self, q, k, v, lenghts):
(B, L, H, E) = q.size()

# Attention bias
attn_bias = None
if lenghts is not None:
attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(lenghts)

# Calculate output
output = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0)
output = xops.memory_efficient_attention(q, k, v, attn_bias = attn_bias)

return output

def direct_attention(self, q, k, v, padding_mask):
(B, H, L, E) = q.size()
def pytorch_attention(self, q, k, v, lenghts):
(B, L, H, E) = q.size()

# Transpose
q = rearrange(q, 'B L H E -> B H L E')
k = rearrange(k, 'B L H E -> B H L E')
v = rearrange(v, 'B L H E -> B H L E')

# Attention bias
# attn_bias = torch.zeros(L, S, dtype=q.dtype, device = q.device)

attn_bias = None
if lenghts is not None:
attn_bias = create_block_mask(lenghts, q.device)
attn_bias = torch.where(attn_bias, 0, torch.tensor(-10000.0, dtype = q.dtype))

# Calculate output
output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask = attn_bias)
output = output.transpose(1, 2)

return output

def direct_attention(self, q, k, v, lenghts):
(B, L, H, E) = q.size()

# Transpose
q = rearrange(q, 'B L H E -> B H L E')
k = rearrange(k, 'B L H E -> B H L E')
v = rearrange(v, 'B L H E -> B H L E')

# Similarity
scale = 1 / math.sqrt(E)
attn_weight = q @ k.transpose(-2, -1)
attn_weight = attn_weight / math.sqrt(E)
print(attn_weight.shape)
attn_weight = attn_weight * scale

# Attention bias
if lenghts is not None:
attn_bias = create_block_mask(lenghts, q.device)
attn_bias = torch.where(attn_bias, 0, torch.tensor(-10000.0, dtype = q.dtype))
attn_weight += attn_bias

# Softmax
# attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)

# Dropout
attn_weight = torch.dropout(attn_weight, self.dropout_p, train=True)

# Calculate output
output = attn_weight @ v

return output
return output.transpose(1, 2)


def create_block_mask(lengths, device):
L = sum(lengths)
mask = torch.zeros(L, L, dtype = torch.bool, device = device)
for i in range(len(lengths)):
mask[sum(lengths[:i]):sum(lengths[:i + 1]), sum(lengths[:i]):sum(lengths[:i + 1])] = 1
return mask
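The new `lenghts` code path packs several sequences along the length dimension and limits attention to a block-diagonal pattern: `create_block_mask` for the `direct` and `torch` engines, `BlockDiagonalMask.from_seqlens` for xformers, and cumulative sequence offsets for the flash varlen kernel. The sketch below is a minimal illustration using plain fp32 SDPA (nothing here is the repository's code) showing why that is equivalent to running attention on each segment independently; the `seg_id` comparison is also a vectorized one-liner equivalent of the Python loop in `create_block_mask`.

```python
# Sketch: a block-diagonal mask over packed sequences gives the same result as
# attending within each segment separately. Names are illustrative; SDPA uses
# the head-first (B, H, L, E) layout, unlike the module above.
import torch
import torch.nn.functional as F

lengths = [4, 8, 8, 12]
B, H, E = 1, 4, 16
L = sum(lengths)

q = torch.rand(B, H, L, E)
k = torch.rand(B, H, L, E)
v = torch.rand(B, H, L, E)

# Segment id per position; mask is True where query i and key j share a segment.
seg_id = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))
block_mask = seg_id[:, None] == seg_id[None, :]              # (L, L) bool

packed = F.scaled_dot_product_attention(q, k, v, attn_mask=block_mask)

# Reference: run attention on each segment on its own, then re-concatenate.
chunks, offset = [], 0
for n in lengths:
    s = slice(offset, offset + n)
    chunks.append(F.scaled_dot_product_attention(q[:, :, s], k[:, :, s], v[:, :, s]))
    offset += n
reference = torch.cat(chunks, dim=2)

print("max abs diff:", (packed - reference).abs().max().item())  # ~1e-7 in fp32
```

The `seqlens` built in `flash_attention` above are simply the running prefix sums of the lengths, e.g. `[0, 4, 12, 20, 32]` for `[4, 8, 8, 12]`, which is the cumulative-sequence-length format `flash_attn_varlen_func` consumes.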
4 changes: 0 additions & 4 deletions supervoice_valle/model_nar.py
@@ -145,10 +145,6 @@ def forward(self, *, condition_text, condition_audio, audio, codec, loss = False
for b in range(B):
x.append(torch.cat([x_t[b], x_a[b], x_ci[b]]))
x, m = list_to_tensors(x)
# print(m)
# m.expand(-1, heads, q_len, -1)
# m = m.unsqueeze(-1).unsqueeze(-1)
# m = m.contiguous()

#
# Transform
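The commented-out lines deleted here (and the similar ones in `transformer.py` below) were broadcasting a per-token padding mask up to attention shape. For context, here is a minimal sketch of that broadcast with plain SDPA, assuming a boolean `(B, L)` mask where `True` marks real tokens; all names are illustrative and nothing here is the repository's `list_to_tensors` helper.

```python
# Sketch: expanding a (B, L) key-padding mask to the (B, 1, 1, L) shape that
# scaled_dot_product_attention can broadcast, which is what the deleted
# comments were gesturing at. True means "attend to this key".
import torch
import torch.nn.functional as F

B, H, L, E = 2, 4, 6, 16
lengths = torch.tensor([4, 6])
padding_mask = torch.arange(L)[None, :] < lengths[:, None]   # (B, L), True = real token

q = torch.rand(B, H, L, E)
k = torch.rand(B, H, L, E)
v = torch.rand(B, H, L, E)

# (B, L) -> (B, 1, 1, L): every query in a sample sees the same set of valid keys.
attn_mask = padding_mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 16])
```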
19 changes: 12 additions & 7 deletions supervoice_valle/transformer.py
@@ -97,9 +97,6 @@ def __init__(self, n_heads, n_dim, n_dim_head, n_dim_ffn, att_dropout, ffn_dropo
torch.nn.init.normal_(self.attention_output.weight, mean=0.0, std=0.02)
torch.nn.init.zeros_(self.attention_output.bias)

# Attention dropout
# self.attention_output_dropout = nn.Dropout(dropout)

# MLP part
self.mlp_ln = RMSNorm(n_dim)

@@ -127,20 +124,28 @@ def forward(self, x, mask = None):
# Calculation Q/K/V for each head
q, k, v = self.attention(y).chunk(3, dim = -1)

with record_function("attention:run"):
# Reshape for head-first attention
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.n_heads), (q, k, v))

# Prepare mask
# mask = None
# if padding_mask is not None:
# mask = rearrange(mask, "b j -> b 1 1 j")
# mask = mask.expand(-1, self.d_heads, q_len, -1)

# Run through attention
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.att_dropout if self.training else 0.0, attn_mask = mask)
y = rearrange(y, 'b h n d -> b n (h d)')

with record_function("attention:post"):
# Reshape back
y = rearrange(y, 'b h n d -> b n (h d)')

# Output
y = self.attention_output(y)
# y = self.attention_output_dropout(y)

# Residual
y = residual + y
residual = y

with record_function("attention:post-post"):
# MLP
y = self.mlp_ln(y)
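The `record_function` scopes added around the attention stages are presumably imported from `torch.profiler` (or `torch.autograd.profiler`) and only matter when a profiler is active. Below is a minimal sketch of how such labelled scopes surface in a profile, assuming a CUDA device; `toy_attention` is a placeholder, not this repository's transformer block.

```python
# Sketch: labelled record_function(...) scopes, like the "attention:run" /
# "attention:post" blocks above, appear as named rows in a torch.profiler table.
import torch
from einops import rearrange
from torch.profiler import ProfilerActivity, profile, record_function


def toy_attention(q, k, v):
    with record_function("attention:run"):
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    with record_function("attention:post"):
        y = rearrange(y, 'b h n d -> b n (h d)')
    return y


q = k = v = torch.rand(2, 8, 128, 64, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    toy_attention(q, k, v)

# Each labelled scope shows up next to the underlying attention kernels, so the
# cost of the "run" and "post" stages can be compared directly.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```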
