
Commit

wip: working on benchmark
ex3ndr committed Jul 2, 2024
1 parent 83693bd commit cda1e0e
Showing 2 changed files with 98 additions and 91 deletions.
17 changes: 10 additions & 7 deletions benchmark.py
@@ -94,6 +94,8 @@ def main():
model = torch.compile(model, mode="reduce-overhead")
model, optim = accelerator.prepare(model, optim)

# torch.cuda.set_sync_debug_mode("error")

# Train step
def train_step():
model.train()
@@ -140,10 +142,11 @@ def train_step():
audio_split = random.randint(min_duration, max_duration)
else:
audio_split = max_duration
audio_full.append(a[:, :audio_split].to(device, non_blocking=True))
audio_partial.append(a[:, audio_split:].to(device, non_blocking=True))
audio_codecs.append(random.randint(1, 7))
texts.append(t.to(device, non_blocking=True))
with record_function("load_batch:append"):
audio_full.append(a[:, :audio_split].to(device, non_blocking=True))
audio_partial.append(a[:, audio_split:].to(device, non_blocking=True))
audio_codecs.append(random.randint(1, 7))
texts.append(t.to(device, non_blocking=True))

# Forward
with record_function("forward"):
@@ -155,9 +158,9 @@ def train_step():
loss = True
)

# Check if loss is NaN
if torch.isnan(loss):
raise ValueError("Loss is NaN")
# # Check if loss is NaN
# if torch.isnan(loss):
# raise ValueError("Loss is NaN")

# Backprop
with record_function("backward"):
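The ranges labelled above with record_function ("load_batch:append", "forward", "backward") only show up when the step is executed under a torch.profiler session, and the commented-out torch.cuda.set_sync_debug_mode("error") call, once enabled, makes PyTorch raise on operations that trigger an implicit CUDA synchronization. The snippet below is a minimal sketch of how such a loop is typically profiled; the warm-up count and the outer "train_step" label are illustrative and not part of this commit.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Run a few steps under the profiler so the named ranges appear as
# labelled blocks in the timeline and summary table.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        with record_function("train_step"):
            train_step()

# Aggregate view sorted by total CUDA time; export_chrome_trace gives a timeline.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")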
172 changes: 88 additions & 84 deletions supervoice_valle/model_nar.py
@@ -3,6 +3,7 @@
from .transformer import Transformer
from .tensors import sinusoids, list_to_tensors
from torch.nn.utils.rnn import pad_sequence
from torch.profiler import record_function

class SupervoceNARModel(torch.nn.Module):
def __init__(self):
@@ -46,96 +47,99 @@ def __init__(self):
self.prediction = torch.nn.Linear(self.n_dim, 1024, bias=False)

def forward(self, *, condition_text, condition_audio, audio, codec, loss = False):

# Check inputs
device = condition_text[0].device
B = len(condition_text)
assert len(condition_audio) == B
assert len(audio) == B
assert len(codec) == B

# Check shapes
for b in range(B):

# Check condition shape
assert condition_text[b].dim() == 1, f"Unexpected shape: {condition_text[b].shape}"
assert condition_audio[b].dim() == 2, f"Unexpected shape: {condition_audio[b].shape}"
assert condition_audio[b].shape[0] == 8, f"Unexpected shape: {condition_audio[b].shape}"

# Check codec value
assert codec[b] >= 1 and codec[b] <= 7, f"Unexpected codec value: {codec[b]}"

# Check audio shape
assert audio[b].dim() == 2, f"Unexpected shape: {audio[b].shape}"
assert audio[b].shape[0] >= codec[b], f"Unexpected shape: {audio[b].shape}"

#
# Prepare EOS
#

eos = self.eos_embedding(torch.tensor([0], device = device))
# Prepare
with record_function("prepare"):
# Check inputs
device = condition_text[0].device
B = len(condition_text)
assert len(condition_audio) == B
assert len(audio) == B
assert len(codec) == B

# Check shapes
for b in range(B):

# Check condition shape
assert condition_text[b].dim() == 1, f"Unexpected shape: {condition_text[b].shape}"
assert condition_audio[b].dim() == 2, f"Unexpected shape: {condition_audio[b].shape}"
assert condition_audio[b].shape[0] == 8, f"Unexpected shape: {condition_audio[b].shape}"

# Check codec value
assert codec[b] >= 1 and codec[b] <= 7, f"Unexpected codec value: {codec[b]}"

# Check audio shape
assert audio[b].dim() == 2, f"Unexpected shape: {audio[b].shape}"
assert audio[b].shape[0] >= codec[b], f"Unexpected shape: {audio[b].shape}"

#
# Prepare EOS
#
with record_function("prepare:eos"):
eos = self.eos_embedding(torch.tensor([0]).to(device, non_blocking=True))

#
# Text embedding
#
l_t = []
x_t = []
for b in range(B):
t = torch.cat([self.text_embedding(condition_text[b]), eos])
# t = t + self.positional_embedding[:t.shape[0]]
t = t + self.positional_embedding_text(torch.arange(t.shape[0], device = t.device))
x_t.append(t)
l_t.append(t.shape[0])

#
# Audio embedding
#

x_a = []
l_c = []
l_a = []
for b in range(B):

# Condition embedding
t_c = self.audio_embedding(condition_audio[b][0])
for i in range(1, condition_audio[b].shape[0]):
t_c = t_c + self.audio_embedding(condition_audio[b][i])

#
# Text embedding
#
with record_function("prepare:text"):
l_t = []
x_t = []
for b in range(B):
t = torch.cat([self.text_embedding(condition_text[b]), eos])
# t = t + self.positional_embedding[:t.shape[0]]
t = t + self.positional_embedding_text(torch.arange(t.shape[0]).to(t.device, non_blocking=True))
x_t.append(t)
l_t.append(t.shape[0])

#
# Audio embedding
t_a = self.audio_embedding(audio[b][0])
for i in range(1, codec[b]):
t_a = t_a + self.audio_embedding(audio[b][i])
#

x_a = []
l_c = []
l_a = []
for b in range(B):

# Condition embedding
t_c = self.audio_embedding(condition_audio[b][0])
for i in range(1, condition_audio[b].shape[0]):
t_c = t_c + self.audio_embedding(condition_audio[b][i])

# Audio embedding
t_a = self.audio_embedding(audio[b][0])
for i in range(1, codec[b]):
t_a = t_a + self.audio_embedding(audio[b][i])

# Concatenate all
t = torch.cat([t_c, t_a, eos])

# Positional embedding
# t = t + self.positional_embedding[:t.shape[0]]
t = t + self.positional_embedding_audio(torch.arange(t.shape[0], device = t.device))

# Append
x_a.append(t)
l_c.append(t_c.shape[0])
l_a.append(t_a.shape[0])

#
# Codec embedding
#

x_ci = []
for b in range(B):
t_ci = self.codec_index_embedding(torch.tensor([codec[b] - 1], device = device).long())
x_ci.append(t_ci)
# Concatenate all
t = torch.cat([t_c, t_a, eos])

# Positional embedding
# t = t + self.positional_embedding[:t.shape[0]]
t = t + self.positional_embedding_audio(torch.arange(t.shape[0]).to(t.device, non_blocking=True))

# Append
x_a.append(t)
l_c.append(t_c.shape[0])
l_a.append(t_a.shape[0])

#
# Codec embedding
#

x_ci = []
for b in range(B):
t_ci = self.codec_index_embedding(torch.tensor([codec[b] - 1]).long().to(t.device, non_blocking=True))
x_ci.append(t_ci)

#
# Concatenate all
#
#
# Concatenate all
#

x = []
for b in range(B):
x.append(torch.cat([x_t[b], x_a[b], x_ci[b]]))
x, m = list_to_tensors(x)
m = m.unsqueeze(-1).unsqueeze(-1)
x = []
for b in range(B):
x.append(torch.cat([x_t[b], x_a[b], x_ci[b]]))
x, m = list_to_tensors(x)
m = m.unsqueeze(-1).unsqueeze(-1)

#
# Transform
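One note on the device transfers introduced throughout this file: .to(device, non_blocking=True) only overlaps the host-to-device copy with compute when the source tensor sits in pinned (page-locked) host memory; for small tensors built on the fly, such as torch.tensor([0]), the copy stays effectively synchronous. A short illustrative sketch follows; the buffer name and shape are placeholders, not from this model.

import torch

device = torch.device("cuda")

# Pinned host buffer: non_blocking=True can overlap the copy with GPU work.
batch = torch.randn(8, 16000).pin_memory()
batch_gpu = batch.to(device, non_blocking=True)

# Ad-hoc pageable tensor (as in torch.tensor([0]) above): non_blocking=True
# is accepted, but the copy is not actually asynchronous without pinning.
idx = torch.tensor([0]).to(device, non_blocking=True)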
