Showing 9 changed files with 321 additions and 122,744 deletions.
@@ -3,4 +3,5 @@
.env
__pycache__
/wandb
.ipynb_checkpoints
.ipynb_checkpoints
train_tokenizer_text.txt
@@ -0,0 +1,143 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9145ba44-c6d3-4ac1-874e-f54e110c15cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from supervoice_valle import Attend\n",
    "from torch.nn.attention import SDPBackend, sdpa_kernel\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c585a532-bf54-4abf-b306-5d820151acd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "a_pt = Attend(engine = \"torch\", heads = 32).to(\"cuda\").eval()\n",
    "a_nt = Attend(engine = \"direct\", heads = 32).to(\"cuda\").eval()\n",
    "a_xt = Attend(engine = \"xformers\", heads = 32).to(\"cuda\").eval()\n",
    "a_ft = Attend(engine = \"flash\", heads = 32).to(\"cuda\").eval()\n",
    "attentions = [a_pt, a_nt, a_xt, a_ft]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bfdeb604-64a0-4f94-9127-9700ba09c9ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Without padding\n",
      "torch 0.0\n",
      "direct 0.00048828125\n",
      "xformers 0.0\n",
      "flash 0.0\n",
      "With padding\n",
      "torch 0.0\n",
      "direct 0.00048828125\n",
      "xformers 0.000244140625\n",
      "flash 0.000244140625\n"
     ]
    }
   ],
   "source": [
    "# Source\n",
    "query = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
    "key = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
    "value = torch.rand(1, 32, 32, 16, dtype=torch.float16, device=\"cuda\")\n",
    "lengths = [4, 8, 8, 12]\n",
    "\n",
    "print(\"Without padding\")\n",
    "source = a_pt(query, key, value)\n",
    "for a in attentions:\n",
    "    dest = a(query, key, value)\n",
    "    print(a.engine, (dest - source).abs().max().item())\n",
    "\n",
    "print(\"With padding\")\n",
    "source = a_pt(query, key, value, lenghts = lengths)\n",
    "for a in attentions:\n",
    "    dest = a(query, key, value, lenghts = lengths)\n",
    "    print(a.engine, (dest - source).abs().max().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "039f7197-074f-4776-a5e7-8c2b5eaabe4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Without padding\n",
      "torch 0.4039938449859619\n",
      "direct 0.9630444049835205\n",
      "xformers 0.9124343395233154\n",
      "flash 0.3562278747558594\n",
      "With padding\n",
      "torch 1.6574337482452393\n",
      "direct 2.504969835281372\n",
      "xformers 1.5768730640411377\n",
      "flash 1.4189743995666504\n"
     ]
    }
   ],
   "source": [
    "# Benchmark\n",
    "print(\"Without padding\")\n",
    "for a in attentions:\n",
    "    start = time.time()\n",
    "    for i in range(10000):\n",
    "        a(query, key, value)\n",
    "    print(a.engine, time.time() - start)\n",
    "\n",
    "print(\"With padding\")\n",
    "for a in attentions:\n",
    "    start = time.time()\n",
    "    for i in range(10000):\n",
    "        a(query, key, value, lenghts = lengths)\n",
    "    print(a.engine, time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05aa22c4-38ca-48e2-8402-f7de96c9be17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
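The notebook above packs four sequences (lengths 4, 8, 8, 12, totaling 32 tokens) into a single batch row and expects every engine to prevent attention across segment boundaries. As a minimal CPU-only sketch of that semantics, independent of the supervoice_valle package and with made-up small head counts (the notebook itself runs in fp16 on CUDA), packed attention under a block-diagonal mask can be checked against running each segment separately:

import torch
import torch.nn.functional as F

# Hypothetical small sizes for a CPU-only check.
lengths = [4, 8, 8, 12]            # packed sequence lengths, summing to 32
L, H, E = sum(lengths), 4, 16      # total tokens, heads, head dim
q = torch.randn(1, H, L, E)        # (batch, heads, seq, head_dim) layout for SDPA
k = torch.randn(1, H, L, E)
v = torch.randn(1, H, L, E)

# Block-diagonal boolean mask: token i may attend to token j only if both
# belong to the same packed segment.
mask = torch.zeros(L, L, dtype=torch.bool)
offset = 0
for n in lengths:
    mask[offset:offset + n, offset:offset + n] = True
    offset += n

# Packed attention with the mask vs. running each segment on its own.
packed = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
pieces = []
offset = 0
for n in lengths:
    s = slice(offset, offset + n)
    pieces.append(F.scaled_dot_product_attention(q[:, :, s], k[:, :, s], v[:, :, s]))
    offset += n
separate = torch.cat(pieces, dim=2)

print((packed - separate).abs().max().item())  # ~1e-7: the two agree

This mirrors the small numerical differences the notebook reports between engines: each backend realizes the same block-diagonal attention, up to floating-point error.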
@@ -1,65 +1,128 @@
import torch
from torch import nn
import math
from einops import rearrange
import xformers.ops as xops
from flash_attn import flash_attn_func, flash_attn_varlen_func

class Attend(nn.Module):
    def __init__(self, *, heads, dropout = 0., engine = "direct"):
class Attend(torch.nn.Module):
    def __init__(self, *, heads, engine = "direct"):
        super().__init__()
        self.heads = heads
        self.dropout = nn.Dropout(dropout)
        self.dropout_p = dropout
        self.engine = engine

    def forward(self, q, k, v, padding_mask = None):
    def forward(self, q, k, v, lenghts = None):

        # Check argument shapes
        assert q.dim() == 4
        assert k.dim() == 4
        assert v.dim() == 4
        assert q.size(0) == k.size(0) == v.size(0), "Batch size mismatch"
        assert q.size(1) == k.size(1) == v.size(1) == self.heads, "Heads length mismatch"
        assert q.size(2) == k.size(2) == v.size(2), "Sequence length mismatch"
        assert q.size(1) == k.size(1) == v.size(1), "Sequence length mismatch"
        assert q.size(2) == k.size(2) == v.size(2) == self.heads, "Heads length mismatch"
        assert q.size(3) == k.size(3) == v.size(3), "Embeddings dimensions mismatch"
        if padding_mask is not None:
            assert padding_mask.dim() == 2
            assert padding_mask.size(0) == q.size(0)
            assert padding_mask.size(2) == q.size(2)
        if lenghts is not None:
            assert sum(lenghts) == q.size(1)

        # Check padding mask
        if self.engine == "direct":
            return self.direct_attention(q, k, v, padding_mask)
            return self.direct_attention(q, k, v, lenghts)
        elif self.engine == "torch":
            return self.pytorch_attention(q, k, v, padding_mask)
            return self.pytorch_attention(q, k, v, lenghts)
        elif self.engine == "xformers":
            return self.xformers_attention(q, k, v, lenghts)
        elif self.engine == "flash":
            return self.flash_attention(q, k, v, lenghts)
        else:
            raise ValueError("Invalid engine")

    def pytorch_attention(self, q, k, v, padding_mask):
        (B, H, L, E) = q.size()
    def flash_attention(self, q, k, v, lengths):
        (B, L, H, E) = q.size()

        # With lengths
        if lengths is not None:

            # Max lengths
            max_len = torch.tensor(max(lengths), dtype = q.dtype, device = q.device)

            # Seq lens
            seqlens = [0]
            last = 0
            for l in lengths:
                last += l
                seqlens.append(last)
            seqlens = torch.tensor(seqlens, dtype = torch.int32, device = q.device)

            return flash_attn_varlen_func(q.squeeze(0), k.squeeze(0), v.squeeze(0), seqlens, seqlens, max_len, max_len).unsqueeze(0)

        # No lengths
        return flash_attn_func(q, k, v)


    def xformers_attention(self, q, k, v, lenghts):
        (B, L, H, E) = q.size()

        # Attention bias
        attn_bias = None
        if lenghts is not None:
            attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(lenghts)

        # Calculate output
        output = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0)
        output = xops.memory_efficient_attention(q, k, v, attn_bias = attn_bias)

        return output

    def direct_attention(self, q, k, v, padding_mask):
        (B, H, L, E) = q.size()
    def pytorch_attention(self, q, k, v, lenghts):
        (B, L, H, E) = q.size()

        # Transpose
        q = rearrange(q, 'B L H E -> B H L E')
        k = rearrange(k, 'B L H E -> B H L E')
        v = rearrange(v, 'B L H E -> B H L E')

        # Attention bias
        # attn_bias = torch.zeros(L, S, dtype=q.dtype, device = q.device)

        attn_bias = None
        if lenghts is not None:
            attn_bias = create_block_mask(lenghts, q.device)
            attn_bias = torch.where(attn_bias, 0, torch.tensor(-10000.0, dtype = q.dtype))

        # Calculate output
        output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask = attn_bias)
        output = output.transpose(1, 2)

        return output

    def direct_attention(self, q, k, v, lenghts):
        (B, L, H, E) = q.size()

        # Transpose
        q = rearrange(q, 'B L H E -> B H L E')
        k = rearrange(k, 'B L H E -> B H L E')
        v = rearrange(v, 'B L H E -> B H L E')

        # Similarity
        scale = 1 / math.sqrt(E)
        attn_weight = q @ k.transpose(-2, -1)
        attn_weight = attn_weight / math.sqrt(E)
        print(attn_weight.shape)
        attn_weight = attn_weight * scale

        # Attention bias
        if lenghts is not None:
            attn_bias = create_block_mask(lenghts, q.device)
            attn_bias = torch.where(attn_bias, 0, torch.tensor(-10000.0, dtype = q.dtype))
            attn_weight += attn_bias

        # Softmax
        # attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)

        # Dropout
        attn_weight = torch.dropout(attn_weight, self.dropout_p, train=True)

        # Calculate output
        output = attn_weight @ v

        return output
        return output.transpose(1, 2)


def create_block_mask(lengths, device):
    L = sum(lengths)
    mask = torch.zeros(L, L, dtype = torch.bool, device = device)
    for i in range(len(lengths)):
        mask[sum(lengths[:i]):sum(lengths[:i + 1]), sum(lengths[:i]):sum(lengths[:i + 1])] = 1
    return mask
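The flash_attention path above converts the per-sequence lengths into cumulative boundaries before passing them, together with the maximum segment length, to flash_attn_varlen_func. A minimal sketch of that bookkeeping, using torch.cumsum instead of the explicit loop (variable names here are illustrative, not the module's):

import torch

# Hypothetical lengths matching the notebook; the boundary tensor marks where
# each packed segment starts and ends in the flattened token stream.
lengths = [4, 8, 8, 12]
cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0)
max_seqlen = max(lengths)

print(cu_seqlens.tolist())  # [0, 4, 12, 20, 32]
print(max_seqlen)           # 12

For lengths [4, 8, 8, 12] this yields the boundaries [0, 4, 12, 20, 32], the same values the loop in flash_attention accumulates one segment at a time.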