-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
499 additions
and
40 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from supervoice_valle import Supervoice, Tokenizer | ||
from encodec import EncodecModel | ||
from pathlib import Path | ||
import torch | ||
|
||
# Voice-generation script: encodes every `voices/<name>.wav` + `voices/<name>.txt`
# pair into a reusable prompt (`voices/<name>.pt`) and rewrites the index module
# that lists the bundled voice names.

# Load tokenizer
tokenizer = Tokenizer("./tokenizer_text.model")

# Load encodec
encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(6.0)

# We don't need a real model for this task: only the codec and the tokenizer
# are used by `create_voice`, so the AR/NAR models and the vocoder stay None.
model = Supervoice(None, None, encodec_model, None, tokenizer)

# Find all wav files in the voices directory.
# Sorted, because glob order depends on the filesystem and the names are
# written verbatim into the generated index below.
voices_dir = Path("voices")
wav_files = sorted(wav_path.stem for wav_path in voices_dir.glob("*.wav"))

# Generate voices
for voice_id in wav_files:
    print(f"Processing {voice_id}")

    # Transcript of the prompt audio; read as UTF-8 instead of the locale default
    text = (voices_dir / f"{voice_id}.txt").read_text(encoding="utf-8").strip()

    # `create_voice` expects a plain string for file paths
    created_voice = model.create_voice(audio=str(voices_dir / f"{voice_id}.wav"), text=text)
    torch.save(created_voice, str(voices_dir / f"{voice_id}.pt"))

# Generate index file
with open("supervoice_valle/voices_gen.py", "w", encoding="utf-8") as f:
    f.write(f"available_voices = {wav_files!r}\n")
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Packages torch.hub verifies are installed before calling an entrypoint.
# `encodec` and `vocos` are listed because the entrypoint below imports them.
dependencies = ['torch', 'torchaudio', 'encodec', 'vocos']


def supervoice():
    """torch.hub entrypoint: build a ready-to-use `Supervoice` model.

    Loads the text tokenizer that sits next to this file, the 24 kHz EnCodec
    codec (6 kbps target bandwidth), the matching Vocos vocoder, and the AR
    and NAR checkpoints, then assembles them into one module.

    Returns:
        The assembled `Supervoice` module, switched to eval mode. Weights are
        loaded onto the CPU; move the module to another device if needed.
    """

    # Imports (kept local so that importing hubconf itself stays cheap)
    import os
    import torch
    from encodec import EncodecModel
    from vocos import Vocos
    from supervoice_valle import SupervoceNARModel, SupervoceARModel, Tokenizer, Supervoice

    # Load tokenizer
    tokenizer = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer_text.model"))

    # Load encodec and the vocoder that decodes its codes
    vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    encodec_model = EncodecModel.encodec_model_24khz()
    encodec_model.set_target_bandwidth(6.0)

    # Load checkpoints
    ar_model = SupervoceARModel()
    nar_model = SupervoceNARModel()
    checkpoint_ar = torch.hub.load_state_dict_from_url(
        "https://shared.korshakov.com/models/supervoice-valle-ar-600000.pt",
        map_location="cpu",
    )
    checkpoint_nar = torch.hub.load_state_dict_from_url(
        "https://shared.korshakov.com/models/supervoice-valle-nar-600000.pt",
        map_location="cpu",
    )
    ar_model.load_state_dict(checkpoint_ar['model'])
    nar_model.load_state_dict(checkpoint_nar['model'])

    # Create model
    model = Supervoice(ar_model, nar_model, encodec_model, vocos, tokenizer)

    # Switch to eval mode
    model.eval()

    return model
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .model_nar import * | ||
from .model_ar import * | ||
from .model import * | ||
from .tokenizer import * | ||
from .transformer import * | ||
from .attention import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torchaudio | ||
from encodec import EncodecModel | ||
from encodec.utils import convert_audio | ||
from .voices_gen import available_voices | ||
import os | ||
|
||
class Supervoice(nn.Module):
    """Text-to-speech pipeline combining an AR model, a NAR model, a codec and a vocoder.

    A "voice" is a dict with two entries: ``"audio_tokens"`` (codec codes of a
    prompt recording) and ``"text_tokens"`` (the tokenized transcript of that
    recording). `create_voice` builds such a dict and `synthesize` consumes one.
    """

    # Sample rate prompt audio is brought to before it is handed to the codec.
    _PROMPT_SAMPLE_RATE = 16000

    # AR sampling stops on any code above this value. EnCodec codebooks hold
    # 1024 entries, so larger ids are presumably the stop token — TODO confirm.
    _MAX_AUDIO_CODE = 1023

    # Hard cap on the AR sequence length (prompt included) to guarantee termination.
    _MAX_AR_LENGTH = 2000

    # Punctuation normalization table.
    # NOTE(review): the table this was derived from contained ASCII identity
    # pairs (",", ";", "?", "!", "(", ")", ":" mapped to themselves), which looks
    # like full-width punctuation that lost its width somewhere. The full-width
    # forms are mapped here; ASCII input behaves exactly as before. Confirm
    # against the libriheavy normalization recipe.
    _PUNCTUATION_TABLE = str.maketrans({
        "\u2019": "'",   # right single quotation mark
        "\u2018": "'",   # left single quotation mark
        "\uff0c": ",",   # full-width comma
        "\u3002": ".",   # ideographic full stop
        "\uff1b": ";",   # full-width semicolon
        "\uff1f": "?",   # full-width question mark
        "\uff01": "!",   # full-width exclamation mark
        "\uff08": "(",   # full-width left parenthesis
        "\uff09": ")",   # full-width right parenthesis
        "\uff1a": ":",   # full-width colon
        "-": " ",        # hyphen becomes a space
        "\u300a": "<",   # left double angle bracket
        "\u300b": ">",   # right double angle bracket
        "\u3001": "/",   # ideographic comma
        "\u201c": '"',   # left double quotation mark
        "\u201d": '"',   # right double quotation mark
        "\u3010": "[",   # left black lenticular bracket
        "\u3011": "]",   # right black lenticular bracket
    })

    def __init__(self, model_ar, model_nar, model_encodec, vocoder, tokenizer):
        """Store the pipeline components.

        Any component may be ``None`` when the methods that need it are never
        called (e.g. only the codec and tokenizer are needed by `create_voice`).
        """
        super().__init__()
        self.model_ar = model_ar
        self.model_nar = model_nar
        self.model_encodec = model_encodec
        self.tokenizer = tokenizer
        self.vocoder = vocoder

    @torch.inference_mode()
    def create_voice(self, audio, text):
        """Encode a prompt recording and its transcript into a voice dict.

        Args:
            audio: Path to an audio file (``str`` or ``os.PathLike``), or a
                waveform tensor of shape ``(T,)`` or ``(1, T)``. Tensors are
                not resampled and are treated as already being at 16 kHz.
            text: Transcript of the recording.

        Returns:
            ``{"audio_tokens": ..., "text_tokens": ...}`` where the audio
            tokens are the codec codes (moved to the CPU) and the text tokens
            are whatever ``tokenizer.encode`` returns for the normalized text.

        Raises:
            AssertionError: If a tensor input is not 1D or single-channel 2D.
        """
        device = self._device()

        # Load audio
        if isinstance(audio, (str, os.PathLike)):
            audio, sr = torchaudio.load(audio)

            # torchaudio returns (channels, time); mix multi-channel files down
            # to mono, otherwise the squeeze below is a no-op and the codec
            # receives a tensor with one dimension too many.
            if audio.size(0) > 1:
                audio = audio.mean(dim=0, keepdim=True)
            if sr != self._PROMPT_SAMPLE_RATE:
                resample = torchaudio.transforms.Resample(sr, self._PROMPT_SAMPLE_RATE, dtype=audio.dtype)
                audio = resample(audio)
            audio = audio.squeeze(0)
        else:
            assert audio.dim() == 2 or audio.dim() == 1, "Audio must be 1D or 2D tensor"
            if audio.dim() == 2:
                assert audio.size(0) == 1, "Audio must have a single channel"
                audio = audio.squeeze(0)

        # Preprocess audio: bring it to the codec's sample rate and channel count
        audio = convert_audio(
            audio.unsqueeze(0),
            self._PROMPT_SAMPLE_RATE,
            self.model_encodec.sample_rate,
            self.model_encodec.channels,
        )

        # Encode audio (the codec takes a batch, hence the extra leading dim)
        wav = audio.unsqueeze(0)
        encoded_frames = self.model_encodec.encode(wav.to(device))

        # Drop only the batch dimension: a bare squeeze() would also collapse
        # a length-1 time axis.
        audio_tokens = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0).cpu()

        # Prepare and tokenize text
        text = self._normalize_text(text)
        text_tokens = self.tokenizer.encode(text)

        # Return
        return {
            "audio_tokens": audio_tokens,
            "text_tokens": text_tokens,
        }

    @torch.inference_mode()
    def synthesize(self, voice, text, top_k = None, top_p = 0.2):
        """Synthesize `text` in the given voice.

        Args:
            voice: Name of a bundled voice (must be listed in
                ``available_voices``) or a voice dict as produced by
                `create_voice`.
            text: Text to speak.
            top_k: Optional top-k cutoff for AR sampling.
            top_p: Optional nucleus (top-p) cutoff for AR sampling.

        Returns:
            The output of ``vocoder.decode`` for the generated codes
            (presumably the waveform — verify against the vocoder).

        Raises:
            ValueError: If `voice` is a name that is not available.
        """
        device = self._device()

        # Prepare text
        text = self._normalize_text(text)

        # Prepare voice
        if isinstance(voice, str):

            # Check if voice is available
            if voice not in available_voices:
                raise ValueError(f"Voice {voice} is not available")

            # Bundled voices live in the `voices` directory next to the package
            current_dir = os.path.dirname(os.path.abspath(__file__))
            voice_file = os.path.join(current_dir, "..", "voices", voice + ".pt")

            # NOTE: torch.load unpickles the file; only the bundled voice files
            # reach this line (the name is checked against the index above).
            voice = torch.load(voice_file, map_location = "cpu")

        # Tokenize text: prompt transcript followed by the text to speak
        text_tokens = torch.cat([voice["text_tokens"].to(device), self.tokenizer.encode(text).to(device)])

        # Audio tokens
        audio_tokens = voice["audio_tokens"].to(device)

        # AR inference continues the first codebook row of the prompt
        coarse_tokens = self.inference_ar(text_tokens, audio_tokens[0], top_k = top_k, top_p = top_p)

        # NAR inference fills in the remaining codebooks
        tokens = self.inference_nar(text_tokens, audio_tokens, coarse_tokens)

        # Vocoder
        features = self.vocoder.codes_to_features(tokens.to(device))
        bandwidth_id = torch.tensor([2], device = device) # 6 kbps
        return self.vocoder.decode(features, bandwidth_id=bandwidth_id)

    @torch.inference_mode()
    def inference_ar(self, text_tokens, audio_tokens, top_k = None, top_p = None):
        """Autoregressively continue a 1D sequence of audio codes.

        One code is sampled per step from the last position of the AR model's
        output until a code above 1023 is drawn or the sequence (prompt
        included) grows past 2000 entries.

        Args:
            text_tokens: 1D tensor of text tokens.
            audio_tokens: 1D tensor of prompt audio codes.
            top_k: Optional top-k cutoff for sampling.
            top_p: Optional nucleus (top-p) cutoff for sampling.

        Returns:
            1D tensor with only the newly generated codes (prompt removed).
        """
        device = self._device()

        # Run inference
        text_tokens = text_tokens.to(device)
        output = audio_tokens.to(device)
        prev = None
        while True:

            # Inference: logits for the next code are at the last position
            logits = self.model_ar(
                text = [text_tokens],
                audio = [output]
            )[0][-1]

            # Sample code
            code, prev = self._sample_ar(logits, top_k = top_k, top_p = top_p, prev = prev)

            # Stop on an out-of-range code or when the length cap is hit
            if (code > self._MAX_AUDIO_CODE) or output.shape[0] > self._MAX_AR_LENGTH:
                break

            # Append code
            output = torch.cat([output, torch.tensor([code], device = output.device)])

        # Cut the prompt audio tokens
        return output[audio_tokens.shape[0]:]

    @torch.inference_mode()
    def inference_nar(self, text_tokens, audio_tokens, coarse_tokens):
        """Predict codebooks 1..7 on top of the coarse (first-codebook) codes.

        Each codebook is predicted greedily by the NAR model, conditioned on
        the text, the prompt audio codes and all codebooks predicted so far.

        Args:
            text_tokens: 1D tensor of text tokens.
            audio_tokens: Prompt audio codes used as conditioning.
            coarse_tokens: 1D tensor of first-codebook codes to refine.

        Returns:
            Tensor stacking the coarse codes and the 7 predicted codebooks,
            i.e. 8 rows in total.
        """
        device = self._device()

        # Run inference
        condition_text = text_tokens.to(device)
        condition_audio = audio_tokens.to(device)
        predicted = [coarse_tokens.to(device)]
        for i in range(1, 8):

            # Inference
            logits = self.model_nar(
                condition_text = [condition_text],
                condition_audio = [condition_audio],
                audio = [torch.stack(predicted)],
                codec = [i]
            )[0]

            # Greedy choice; softmax is monotonic, so argmax over the raw
            # logits selects the same code without the extra pass.
            predicted.append(torch.argmax(logits, dim=-1))

        # Result
        return torch.stack(predicted)

    def _device(self):
        """Return the device of the first parameter, or the CPU if there is none."""
        first = next(self.parameters(), None)
        if first is None:
            return torch.device("cpu")
        return first.device

    def _normalize_text(self, text):
        """Map typographic/CJK punctuation to ASCII and strip surrounding whitespace."""
        # This method follows the same normalization of the libriheavy dataset
        return text.translate(self._PUNCTUATION_TABLE).strip()

    def _sample_ar(self, logits, top_k = None, top_p = None, prev = None):
        """Sample one code from `logits` with optional top-k / top-p filtering.

        Args:
            logits: 1D tensor of unnormalized scores. It is not modified.
            top_k: If given, only the `top_k` highest-scoring codes are kept
                (clamped to the number of available codes).
            top_p: If given, only the smallest set of highest-probability codes
                whose cumulative probability exceeds `top_p` is kept; the most
                probable code always survives.
            prev: Unused; accepted for call compatibility.

        Returns:
            ``(code, None)`` where `code` is the sampled index as a Python int.
        """
        neg_inf = float('-inf')

        # Top-k
        if top_k is not None:

            # Threshold is the k-th largest value; everything below it is dropped.
            # masked_fill returns a new tensor, leaving the caller's logits intact.
            k = min(top_k, logits.size(-1))
            threshold = torch.topk(logits, k, dim=-1)[0][..., -1, None]
            logits = logits.masked_fill(logits < threshold, neg_inf)

        # Top-p
        if top_p is not None:

            # Sort logits
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)

            # Calculate cumulative probabilities
            cum_sum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove all indices with cumulative probability more than top_p
            sorted_indices_to_remove = cum_sum_probs > top_p

            # Shift the mask to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            # Assign minus infinity for such values
            sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, neg_inf)

            # Then reverse the sorting process by mapping back sorted_logits to their original position
            logits = torch.gather(sorted_logits, -1, sorted_indices.argsort(-1))

        # Softmax
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Sample
        return torch.multinomial(probs, num_samples=1).item(), None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Names of the bundled voices that can be requested by name.
# NOTE(review): this looks like the index emitted by the voice-generation
# script (it writes `available_voices = [...]` into voices_gen.py) — if so,
# regenerate it rather than editing by hand.
available_voices = ['voice_3', 'voice_1', 'voice_2']
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
We had not stopped to study the Indian character. |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
A man approached through the heavy gloom. |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
He was a typical pirate. |
Binary file not shown.
Oops, something went wrong.