Skip to content

Commit 7fa7295

Browse files
committed
revert modified files when rebasing
Signed-off-by: Xuesong Yang <[email protected]>
1 parent f9bf5b6 commit 7fa7295

File tree

6 files changed

+20
-371
lines changed

6 files changed

+20
-371
lines changed

examples/tts/conf/fastpitch_ssl.yaml

-7
Original file line number | Diff line number | Diff line change
@@ -33,15 +33,12 @@ use_unique_tokens: true
3333
speaker_conditioning_type: per_sample
3434
segment_speaker_embedding: true
3535
ssl_downsampling_factor: 4 # How many mel-spectrogram frames map to one content embedding in the SSL model
36-
content_aug_types: []
3736

3837
model:
3938
ssl_model_ckpt_path: ${ssl_model_ckpt_path}
4039
ssl_downsampling_factor: ${ssl_downsampling_factor}
4140
use_encoder: true
4241
use_duration_predictor: ${use_unique_tokens}
43-
emb_similarity_threshold: 1.0 # Group content embeddings at consecutive timesteps together if their cosine similarity is above this threshold. 1.0 means no grouping.
44-
4542
pitch_conditioning: true
4643
pitch_loss_scale: 1.0
4744
learn_alignment: true
@@ -66,8 +63,6 @@ model:
6663
speaker_emb_indim: 256
6764
content_emb_outdim: 192
6865
speaker_emb_outdim: 192
69-
70-
content_aug_types: ${content_aug_types}
7166

7267
train_ds:
7368
dataset:
@@ -85,8 +80,6 @@ model:
8580
pad_multiple: 1024
8681
speaker_conditioning_type: ${speaker_conditioning_type}
8782
sup_data_dir: ${sup_data_dir}
88-
content_aug_types: ${content_aug_types}
89-
emb_similarity_threshold: ${model.emb_similarity_threshold}
9083

9184
dataloader_params:
9285
drop_last: false

examples/tts/fastpitch_ssl.py

+1-8
Original file line number | Diff line number | Diff line change
@@ -26,14 +26,7 @@ def main(cfg):
2626
exp_manager(trainer, cfg.get("exp_manager", None))
2727
vocoder = hifigan.HifiGanModel.load_from_checkpoint(cfg.hifi_ckpt_path).cpu()
2828
vocoder.eval()
29-
ssl_model = None
30-
if cfg.get("ssl_model_ckpt_path", None):
31-
ssl_model = nemo_asr.models.ssl_models.SpeechEncDecSelfSupervisedModel.load_from_checkpoint(
32-
cfg.ssl_model_ckpt_path
33-
)
34-
ssl_model.eval()
35-
36-
model = fastpitch_ssl.FastPitchModel_SSL(cfg=cfg.model, trainer=trainer, vocoder=vocoder, ssl_model=ssl_model)
29+
model = fastpitch_ssl.FastPitchModel_SSL(cfg=cfg.model, trainer=trainer, vocoder=vocoder)
3730
if cfg.get("finetune", False):
3831
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
3932
lr_logger = pl.callbacks.LearningRateMonitor()

nemo/collections/tts/data/dataset.py

+3-117
Original file line number | Diff line number | Diff line change
@@ -1216,9 +1216,6 @@ def __init__(
12161216
sup_data_dir: Optional[Union[str, Path]] = None,
12171217
speaker_stats_pitch_fp: Optional[Union[str, Path]] = None,
12181218
speaker_conditioning_type: Optional[str] = "per_sample", # per_sample, mean, interpolate,
1219-
content_aug_types: Optional[List[str]] = [],
1220-
alternate_speaker_conditioning: Optional[str] = "random",
1221-
emb_similarity_threshold: Optional[float] = 1.0, # Set to 1.0 to disable grouping
12221219
):
12231220
"""Dataset used for training FastPitchModel_SSL model.
12241221
Requires supplementary data created using scripts/ssl_tts/make_supdata.py
@@ -1313,7 +1310,6 @@ def __init__(
13131310
if sup_data_dir is None:
13141311
sup_data_dir = os.path.join(self.base_data_dir, "sup_data")
13151312
self.sup_data_dir = sup_data_dir
1316-
self.content_aug_types = content_aug_types
13171313

13181314
if self.pitch_normalization == "speaker_wise":
13191315
self.speaker_stats = {}
@@ -1331,10 +1327,6 @@ def __init__(
13311327
for key in speaker_stats_raw:
13321328
self.speaker_stats[int(key)] = speaker_stats_raw[key]
13331329

1334-
self.alternate_speaker_conditioning = alternate_speaker_conditioning
1335-
self.emb_similarity_threshold = emb_similarity_threshold
1336-
self.compute_mean_speaker_embeddings()
1337-
13381330
def _get_wav_from_filepath(self, audio_filepath):
13391331
features = AudioSegment.segment_from_file(
13401332
audio_filepath,
@@ -1355,75 +1347,6 @@ def _get_wav_from_filepath(self, audio_filepath):
13551347

13561348
return audio, audio_length
13571349

1358-
def group_content_embeddings(self, content_embedding, aug_embeddings, duration):
1359-
# content_embedding: (256, n_timesteps)
1360-
grouped_content_embeddings = [content_embedding[:, 0]]
1361-
grouped_durations = [duration[0]]
1362-
grouped_aug_embeddings = {key: [aug_embeddings[key][:, 0]] for key in aug_embeddings}
1363-
group_size = 1
1364-
for _tidx in range(1, content_embedding.shape[1]):
1365-
prev_embedding = grouped_content_embeddings[-1]
1366-
curr_embedding = content_embedding[:, _tidx]
1367-
emb_similarity = torch.cosine_similarity(prev_embedding, curr_embedding, dim=0)
1368-
if emb_similarity < self.emb_similarity_threshold:
1369-
grouped_content_embeddings.append(curr_embedding)
1370-
grouped_durations.append(duration[_tidx])
1371-
for key in aug_embeddings:
1372-
grouped_aug_embeddings[key].append(aug_embeddings[key][:, _tidx])
1373-
else:
1374-
# group with previous embedding
1375-
grouped_content_embeddings[-1] = (grouped_content_embeddings[-1] * group_size + curr_embedding) / (
1376-
group_size + 1
1377-
)
1378-
grouped_durations[-1] += duration[_tidx]
1379-
for key in aug_embeddings:
1380-
grouped_aug_embeddings[key][-1] = (
1381-
grouped_aug_embeddings[key][-1] * group_size + aug_embeddings[key][:, _tidx]
1382-
) / (group_size + 1)
1383-
group_size += 1
1384-
1385-
grouped_content_embeddings = torch.stack(grouped_content_embeddings, dim=1)
1386-
grouped_durations = torch.stack(grouped_durations, dim=0)
1387-
grouped_aug_embeddings = {
1388-
key: torch.stack(grouped_aug_embeddings[key], dim=1) for key in grouped_aug_embeddings
1389-
}
1390-
1391-
return grouped_content_embeddings, grouped_aug_embeddings, grouped_durations
1392-
1393-
def compute_mean_speaker_embeddings(self, n_embeddings_per_speaker=50):
1394-
print("computing mean speaker embeddings...")
1395-
mean_speaker_embeddings = {}
1396-
speaker_counts = {}
1397-
for idx in range(len(self.data)):
1398-
if idx % 1000 == 0:
1399-
print("processed", idx, len(self.data), "files")
1400-
sample = self.data[idx]
1401-
speaker = sample["speaker"]
1402-
if speaker in speaker_counts and speaker_counts[speaker] >= n_embeddings_per_speaker:
1403-
continue
1404-
1405-
rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("")
1406-
rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")
1407-
speaker_emb_fn = f"speaker_embedding_{rel_audio_path_as_text_id}.pt"
1408-
speaker_emb_fp = os.path.join(self.sup_data_dir, speaker_emb_fn)
1409-
if os.path.exists(speaker_emb_fp):
1410-
embedding = torch.load(speaker_emb_fp)
1411-
if speaker not in mean_speaker_embeddings:
1412-
print("adding speaker", len(mean_speaker_embeddings), speaker, speaker_emb_fp)
1413-
mean_speaker_embeddings[speaker] = embedding
1414-
speaker_counts[speaker] = 1
1415-
else:
1416-
mean_speaker_embeddings[speaker] += embedding
1417-
speaker_counts[speaker] += 1
1418-
1419-
for speaker in mean_speaker_embeddings:
1420-
mean_speaker_embeddings[speaker] /= speaker_counts[speaker]
1421-
l2_norm = torch.norm(mean_speaker_embeddings[speaker], p=2)
1422-
mean_speaker_embeddings[speaker] /= l2_norm
1423-
1424-
print("mean speaker embeddings computed")
1425-
self.mean_speaker_embeddings = mean_speaker_embeddings
1426-
14271350
def get_ssl_features(self, wav_text_id):
14281351
content_emb_fn = f"{self.ssl_content_emb_type}_content_embedding_{wav_text_id}.pt"
14291352
speaker_emb_fn = f"speaker_embedding_{wav_text_id}.pt"
@@ -1446,18 +1369,6 @@ def get_ssl_features(self, wav_text_id):
14461369
f"Speaker embedding file {speaker_emb_fp} does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training."
14471370
)
14481371

1449-
aug_embeddings = {}
1450-
if len(self.content_aug_types) > 0:
1451-
for aug_type in self.content_aug_types:
1452-
aug_emb_fn = f"{self.ssl_content_emb_type}_{aug_type}_content_embedding_{wav_text_id}.pt"
1453-
aug_emb_fp = os.path.join(self.sup_data_dir, aug_emb_fn)
1454-
if os.path.exists(aug_emb_fp):
1455-
aug_embeddings[aug_type] = torch.load(aug_emb_fp)
1456-
else:
1457-
raise ValueError(
1458-
f"Augmented content embedding file {aug_emb_fp} does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training."
1459-
)
1460-
14611372
if os.path.exists(duration_fp):
14621373
duration = torch.load(duration_fp)
14631374
else:
@@ -1467,7 +1378,7 @@ def get_ssl_features(self, wav_text_id):
14671378

14681379
encoded_len = torch.tensor(content_embedding.shape[1]).long()
14691380

1470-
return content_embedding, speaker_embedding, encoded_len, duration, aug_embeddings
1381+
return content_embedding, speaker_embedding, encoded_len, duration
14711382

14721383
def get_pitch_contour(self, wav_text_id):
14731384
pitch_contour_fn = f"pitch_contour_{wav_text_id}.pt"
@@ -1531,14 +1442,6 @@ def pad_collate_fn(self, batch):
15311442
duration_padded = torch.nn.functional.pad(duration, (0, max_encoded_len - duration.size(0)), value=0.0)
15321443
durations_padded.append(duration_padded)
15331444

1534-
other_content_embedding_keys = [k for k in final_batch if k.startswith("content_embedding_")]
1535-
for key in other_content_embedding_keys:
1536-
other_content_embeddings_padded = []
1537-
for encoded in final_batch[key]:
1538-
encoded_padded = torch.nn.functional.pad(encoded, (0, max_encoded_len - encoded.size(1)), value=0)
1539-
other_content_embeddings_padded.append(encoded_padded)
1540-
final_batch[key] = other_content_embeddings_padded
1541-
15421445
final_batch["audio"] = audios_padded
15431446
final_batch["mel_spectrogram"] = mels_padded
15441447
final_batch["pitch_contour"] = pitch_contours_padded
@@ -1548,7 +1451,7 @@ def pad_collate_fn(self, batch):
15481451
for key in final_batch:
15491452
final_batch[key] = torch.stack(final_batch[key])
15501453

1551-
return dict(final_batch)
1454+
return final_batch
15521455

15531456
def __getitem__(self, index):
15541457
sample = self.data[index]
@@ -1563,9 +1466,7 @@ def __getitem__(self, index):
15631466
if self.pitch_conditioning:
15641467
pitch_contour = self.get_pitch_contour(rel_audio_path_as_text_id)
15651468

1566-
content_embedding, speaker_embedding, encoded_len, duration, aug_embeddings = self.get_ssl_features(
1567-
rel_audio_path_as_text_id
1568-
)
1469+
content_embedding, speaker_embedding, encoded_len, duration = self.get_ssl_features(rel_audio_path_as_text_id)
15691470

15701471
if self.speaker_conditioning_type == "mean":
15711472
assert sample["speaker"] in self.mean_speaker_embeddings, "{} not in speaker emb".format(sample['speaker'])
@@ -1586,17 +1487,6 @@ def __getitem__(self, index):
15861487
mel_spectrogram = self.get_mel_spectrogram(rel_audio_path_as_text_id)
15871488
mel_len = torch.tensor(mel_spectrogram.shape[1]).long()
15881489

1589-
alternate_speakers = [spk for spk in self.mean_speaker_embeddings if spk != sample["speaker"]]
1590-
if len(alternate_speakers) == 0:
1591-
alternate_speaker = sample["speaker"]
1592-
else:
1593-
if self.alternate_speaker_conditioning == "random":
1594-
alternate_speaker = random.choice(alternate_speakers)
1595-
elif self.alternate_speaker_conditioning == "fixed":
1596-
alternate_speaker = min(alternate_speakers)
1597-
1598-
alternate_speaker_embedding = self.mean_speaker_embeddings[alternate_speaker]
1599-
16001490
if pitch_contour is not None:
16011491
if self.pitch_normalization in ["speaker_wise", "global"]:
16021492
mean, std = self.pitch_mean, self.pitch_std
@@ -1625,7 +1515,6 @@ def __getitem__(self, index):
16251515
'audio_len': audio_length,
16261516
'content_embedding': content_embedding,
16271517
'speaker_embedding': speaker_embedding,
1628-
'alternate_speaker_embedding': alternate_speaker_embedding,
16291518
'encoded_len': encoded_len,
16301519
'pitch_contour': pitch_contour,
16311520
'speaker': speaker,
@@ -1635,9 +1524,6 @@ def __getitem__(self, index):
16351524
'duration': duration,
16361525
}
16371526

1638-
for aug_type in aug_embeddings:
1639-
item["content_embedding_{}".format(aug_type)] = aug_embeddings[aug_type]
1640-
16411527
return item
16421528

16431529
def __len__(self):

0 commit comments

Comments
 (0)