Skip to content

Commit

Permalink
feat: add frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
ex3ndr committed Jul 14, 2024
1 parent 481ed55 commit 66dc03e
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 40 deletions.
Binary file added eval/hifitts_0.flac
Binary file not shown.
31 changes: 31 additions & 0 deletions generate_voices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from supervoice_valle import Supervoice, Tokenizer
from encodec import EncodecModel
from pathlib import Path
import torch

# Load tokenizer
tokenizer = Tokenizer("./tokenizer_text.model")

# Load encodec
encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(6.0)

# We don't need a real model for this task
model = Supervoice(None, None, encodec_model, None, tokenizer)

# Find all wav files in the voices directory
wav_files = list(Path('voices').glob('*.wav'))
wav_files = [f.stem for f in wav_files]

# Generate voices
for id in wav_files:
print(f"Processing {id}")
with open("./voices/" + id + ".txt", 'r') as f:
text = f.read().strip()
created_voice = model.create_voice(audio = "./voices/" + id + ".wav", text = text)
torch.save(created_voice, f"./voices/{id}.pt")

# Generate index file
with open("supervoice_valle/voices_gen.py", "w") as f:
f.write(f"available_voices = {wav_files}")

33 changes: 33 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
dependencies = ['torch', 'torchaudio']

def supervoice():

# Imports
import torch
import os
from supervoice_valle import SupervoceNARModel, SupervoceARModel, Tokenizer, Supervoice

# Load tokenizer
tokenizer = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer_text.model"))

# Load encodec
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(6.0)

# Load checkpoints
ar_model = SupervoceARModel()
nar_model = SupervoceNARModel()
checkpoint_ar = torch.hub.load_state_dict_from_url("https://shared.korshakov.com/models/supervoice-valle-ar-600000.pt", map_location="cpu")
checkpoint_nar = torch.hub.load_state_dict_from_url("https://shared.korshakov.com/models/supervoice-valle-nar-600000.pt", map_location="cpu")
ar_model.load_state_dict(checkpoint_ar['model'])
nar_model.load_state_dict(checkpoint_nar['model'])

# Create model
model = Supervoice(ar_model, nar_model, encodec_model, vocos, tokenizer)

# Switch to eval mode
model.eval()

return model

1 change: 1 addition & 0 deletions supervoice_valle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .model_nar import *
from .model_ar import *
from .model import *
from .tokenizer import *
from .transformer import *
from .attention import *
198 changes: 198 additions & 0 deletions supervoice_valle/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import torch
import torch.nn as nn
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from .voices_gen import available_voices
import os

class Supervoice(nn.Module):
def __init__(self, model_ar, model_nar, model_encodec, vocoder, tokenizer):
super(Supervoice, self).__init__()
self.model_ar = model_ar
self.model_nar = model_nar
self.model_encodec = model_encodec
self.tokenizer = tokenizer
self.vocoder = vocoder

@torch.inference_mode()
def create_voice(self, audio, text):
device = self._device()

# Load audio
if type(audio) is str:
audio, sr = torchaudio.load(audio)
if sr != 16000:
audio = torchaudio.transforms.Resample(sr, 16000, dtype=audio.dtype)(audio)
audio = audio.squeeze(0)
else:
assert audio.dim() == 2 or audio.dim() == 1, "Audio must be 1D or 2D tensor"
if audio.dim() == 2:
assert audio.size(0) == 1, "Audio must have a single channel"
audio = audio.squeeze(0)

# Preprocess audio
audio = convert_audio(audio.unsqueeze(0), 16000, self.model_encodec.sample_rate, self.model_encodec.channels)

# Encode audio
wav = audio.unsqueeze(0)
encoded_frames = self.model_encodec.encode(wav.to(device))
audio_tokens = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu()

# Prepare text
text = self._normalize_text(text)

# Tokenize text
text_tokens = self.tokenizer.encode(text)

# Return
return {
"audio_tokens": audio_tokens,
"text_tokens": text_tokens,
}

@torch.inference_mode()
def synthesize(self, voice, text, top_k = None, top_p = 0.2):
device = self._device()

# Prepare text
text = self._normalize_text(text)

# Prepare voice
if type(voice) is str:

# Check if voice is available
if voice not in available_voices:
raise ValueError(f"Voice {voice} is not available")

# Get the current file directory
current_dir = os.path.dirname(os.path.abspath(__file__))

# Load voice
voice_file = os.path.join(current_dir, "..", "voices", voice + ".pt")
voice = torch.load(voice_file, map_location = "cpu")

# Tokenize text
text_tokens = torch.cat([voice["text_tokens"].to(device), self.tokenizer.encode(text).to(device)])

# Audio tokens
audio_tokens = voice["audio_tokens"].to(device)

# AR inference
coarse_tokens = self.inference_ar(text_tokens, audio_tokens[0], top_k = top_k, top_p = top_p)

# NAR inference
tokens = self.inference_nar(text_tokens, audio_tokens, coarse_tokens)

# Vocoder
features = self.vocoder.codes_to_features(tokens.to(device))
bandwidth_id = torch.tensor([2]).to(device) # 6 kbps
return self.vocoder.decode(features, bandwidth_id=bandwidth_id)

@torch.inference_mode()
def inference_ar(self, text_tokens, audio_tokens, top_k = None, top_p = None):
device = self._device()

# Run inference
text_tokens = text_tokens.to(device)
output = audio_tokens.to(device)
prev = None
while True:

# Inference
p = self.model_ar(
text = [text_tokens],
audio = [output]
)
p = p[0][-1]

# Sample code
code, prev = self._sample_ar(p, top_k = top_k, top_p = top_p, prev = prev)

# Append code
if (code > 1023) or output.shape[0] > 2000:
break
output = torch.cat([output, torch.tensor([code], device = output.device)])

# Cut the audio tokens
output = output[audio_tokens.shape[0]:]

return output

@torch.inference_mode()
def inference_nar(self, text_tokens, audio_tokens, coarse_tokens):
device = self._device()

# Run inference
condition_text = text_tokens.to(device)
condition_audio = audio_tokens.to(device)
predicted = [coarse_tokens.to(device)]
for i in range(1, 8):

# Inference
p = self.model_nar(
condition_text = [condition_text],
condition_audio = [condition_audio],
audio = [torch.stack(predicted)],
codec = [i]
)

# Argmax sampling
p = p[0]
p = torch.nn.functional.softmax(p, dim=-1)
p = torch.argmax(p, dim=-1, keepdim=True)
p = p.squeeze(-1)

# Append
predicted.append(p)

# Result
return torch.stack(predicted)

def _device(self):
return next(self.parameters()).device

def _normalize_text(self, text):
# This method follows the same normalization of the libriheavy dataset
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
text = text.translate(table)
return text.strip()

def _sample_ar(self, logits, top_k = None, top_p = None, prev = None):

# Top-k
if top_k is not None:

# Find all indices which value is less than k-th one
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]

# Assign minus infinity for such values
logits[indices_to_remove] = float('-inf')

# Top-p
if top_p is not None:

# Sort logits
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)

# Calculate cummulative probabilities
cum_sum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

# Remove all indices with cummulative probability more than top_p
sorted_indices_to_remove = cum_sum_probs < top_p

# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# Assign minus infinity for such values
sorted_logits[sorted_indices_to_remove] = float('-inf')

# Then reverse the sorting process by mapping back sorted_logits to their original position
logits = torch.gather(sorted_logits, 0, sorted_indices.argsort(-1))

# Softmax
probs = torch.nn.functional.softmax(logits, dim=-1)

# Sample
return torch.multinomial(probs, num_samples=1).item(), None
1 change: 1 addition & 0 deletions supervoice_valle/voices_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
available_voices = ['voice_3', 'voice_1', 'voice_2']
Binary file added voices/voice_1.pt
Binary file not shown.
1 change: 1 addition & 0 deletions voices/voice_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
We had not stopped to study the Indian character.
Binary file added voices/voice_1.wav
Binary file not shown.
Binary file added voices/voice_2.pt
Binary file not shown.
1 change: 1 addition & 0 deletions voices/voice_2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
A man approached through the heavy gloom.
Binary file added voices/voice_2.wav
Binary file not shown.
Binary file added voices/voice_3.pt
Binary file not shown.
1 change: 1 addition & 0 deletions voices/voice_3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
He was a typical pirate.
Binary file added voices/voice_3.wav
Binary file not shown.
Loading

0 comments on commit 66dc03e

Please sign in to comment.