@@ -1216,9 +1216,6 @@ def __init__(
1216
1216
sup_data_dir : Optional [Union [str , Path ]] = None ,
1217
1217
speaker_stats_pitch_fp : Optional [Union [str , Path ]] = None ,
1218
1218
speaker_conditioning_type : Optional [str ] = "per_sample" , # per_sample, mean, interpolate,
1219
- content_aug_types : Optional [List [str ]] = [],
1220
- alternate_speaker_conditioning : Optional [str ] = "random" ,
1221
- emb_similarity_threshold : Optional [float ] = 1.0 , # Set to 1.0 to disable grouping
1222
1219
):
1223
1220
"""Dataset used for training FastPitchModel_SSL model.
1224
1221
Requires supplementary data created using scripts/ssl_tts/make_supdata.py
@@ -1313,7 +1310,6 @@ def __init__(
1313
1310
if sup_data_dir is None :
1314
1311
sup_data_dir = os .path .join (self .base_data_dir , "sup_data" )
1315
1312
self .sup_data_dir = sup_data_dir
1316
- self .content_aug_types = content_aug_types
1317
1313
1318
1314
if self .pitch_normalization == "speaker_wise" :
1319
1315
self .speaker_stats = {}
@@ -1331,10 +1327,6 @@ def __init__(
1331
1327
for key in speaker_stats_raw :
1332
1328
self .speaker_stats [int (key )] = speaker_stats_raw [key ]
1333
1329
1334
- self .alternate_speaker_conditioning = alternate_speaker_conditioning
1335
- self .emb_similarity_threshold = emb_similarity_threshold
1336
- self .compute_mean_speaker_embeddings ()
1337
-
1338
1330
def _get_wav_from_filepath (self , audio_filepath ):
1339
1331
features = AudioSegment .segment_from_file (
1340
1332
audio_filepath ,
@@ -1355,75 +1347,6 @@ def _get_wav_from_filepath(self, audio_filepath):
1355
1347
1356
1348
return audio , audio_length
1357
1349
1358
def group_content_embeddings(self, content_embedding, aug_embeddings, duration):
    """Merge runs of consecutive, similar content-embedding frames.

    Frames are visited left to right. A frame whose cosine similarity to the
    current group's running mean is below ``self.emb_similarity_threshold``
    opens a new group; otherwise it is folded into the current group (the
    embedding becomes the running mean of the group, the duration is summed).

    Args:
        content_embedding: 2-D tensor indexed as ``[:, t]`` per timestep
            (the original comment gives the shape as ``(256, n_timesteps)``).
        aug_embeddings: dict mapping a key to a tensor indexed the same way
            as ``content_embedding``; each one is grouped with the same
            boundaries as the content embedding.
        duration: tensor indexed by timestep; not modified by this method.

    Returns:
        Tuple ``(grouped_content_embeddings, grouped_aug_embeddings,
        grouped_durations)``: the grouped frames stacked along dim 1, a dict
        of the grouped augmented embeddings stacked along dim 1, and the
        per-group durations stacked along dim 0.
    """
    grouped_content = [content_embedding[:, 0]]
    grouped_durations = [duration[0]]
    grouped_aug = {key: [emb[:, 0]] for key, emb in aug_embeddings.items()}
    group_size = 1  # number of frames already averaged into the current group

    for t in range(1, content_embedding.shape[1]):
        curr_embedding = content_embedding[:, t]
        similarity = torch.cosine_similarity(grouped_content[-1], curr_embedding, dim=0)

        if similarity < self.emb_similarity_threshold:
            # Dissimilar frame: start a fresh group and restart its frame count,
            # otherwise the running mean below would use the previous group's size.
            grouped_content.append(curr_embedding)
            grouped_durations.append(duration[t])
            for key, emb in aug_embeddings.items():
                grouped_aug[key].append(emb[:, t])
            group_size = 1
            continue

        # Similar frame: update the running means of the current group.
        new_size = group_size + 1
        grouped_content[-1] = (grouped_content[-1] * group_size + curr_embedding) / new_size
        # Out-of-place add: duration[i] is a view, so `+=` would write into
        # the caller's duration tensor.
        grouped_durations[-1] = grouped_durations[-1] + duration[t]
        for key, emb in aug_embeddings.items():
            grouped_aug[key][-1] = (grouped_aug[key][-1] * group_size + emb[:, t]) / new_size
        group_size = new_size

    grouped_content_embeddings = torch.stack(grouped_content, dim=1)
    grouped_durations = torch.stack(grouped_durations, dim=0)
    grouped_aug_embeddings = {key: torch.stack(frames, dim=1) for key, frames in grouped_aug.items()}

    return grouped_content_embeddings, grouped_aug_embeddings, grouped_durations
1392
-
1393
def compute_mean_speaker_embeddings(self, n_embeddings_per_speaker=50):
    """Build one L2-normalised mean speaker embedding per speaker.

    Walks ``self.data`` and, for each sample, looks for
    ``speaker_embedding_<text_id>.pt`` in ``self.sup_data_dir``, where
    ``<text_id>`` is the audio path relative to ``self.base_data_dir`` with
    its suffix dropped and ``/`` replaced by ``_``. Samples without such a
    file are ignored. At most ``n_embeddings_per_speaker`` embeddings are
    accumulated per speaker.

    The result is stored on ``self.mean_speaker_embeddings`` as a dict keyed
    by the sample's ``"speaker"`` value; nothing is returned.
    """
    print("computing mean speaker embeddings...")
    embedding_sums = {}
    embeddings_used = {}

    for position, entry in enumerate(self.data):
        if position % 1000 == 0:
            print("processed", position, len(self.data), "files")

        speaker_id = entry["speaker"]
        used_so_far = embeddings_used.get(speaker_id)
        if used_so_far is not None and used_so_far >= n_embeddings_per_speaker:
            # Enough embeddings collected for this speaker already.
            continue

        relative_stem = Path(entry["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("")
        text_id = str(relative_stem).replace("/", "_")
        embedding_path = os.path.join(self.sup_data_dir, f"speaker_embedding_{text_id}.pt")
        if not os.path.exists(embedding_path):
            continue

        loaded = torch.load(embedding_path)
        if speaker_id in embedding_sums:
            embedding_sums[speaker_id] += loaded
            embeddings_used[speaker_id] += 1
        else:
            print("adding speaker", len(embedding_sums), speaker_id, embedding_path)
            embedding_sums[speaker_id] = loaded
            embeddings_used[speaker_id] = 1

    # Turn each sum into a mean, then scale the mean to unit L2 norm.
    for speaker_id in embedding_sums:
        embedding_sums[speaker_id] /= embeddings_used[speaker_id]
        embedding_sums[speaker_id] /= torch.norm(embedding_sums[speaker_id], p=2)

    print("mean speaker embeddings computed")
    self.mean_speaker_embeddings = embedding_sums
1426
-
1427
1350
def get_ssl_features (self , wav_text_id ):
1428
1351
content_emb_fn = f"{ self .ssl_content_emb_type } _content_embedding_{ wav_text_id } .pt"
1429
1352
speaker_emb_fn = f"speaker_embedding_{ wav_text_id } .pt"
@@ -1446,18 +1369,6 @@ def get_ssl_features(self, wav_text_id):
1446
1369
f"Speaker embedding file { speaker_emb_fp } does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training."
1447
1370
)
1448
1371
1449
- aug_embeddings = {}
1450
- if len (self .content_aug_types ) > 0 :
1451
- for aug_type in self .content_aug_types :
1452
- aug_emb_fn = f"{ self .ssl_content_emb_type } _{ aug_type } _content_embedding_{ wav_text_id } .pt"
1453
- aug_emb_fp = os .path .join (self .sup_data_dir , aug_emb_fn )
1454
- if os .path .exists (aug_emb_fp ):
1455
- aug_embeddings [aug_type ] = torch .load (aug_emb_fp )
1456
- else :
1457
- raise ValueError (
1458
- f"Augmented content embedding file { aug_emb_fp } does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training."
1459
- )
1460
-
1461
1372
if os .path .exists (duration_fp ):
1462
1373
duration = torch .load (duration_fp )
1463
1374
else :
@@ -1467,7 +1378,7 @@ def get_ssl_features(self, wav_text_id):
1467
1378
1468
1379
encoded_len = torch .tensor (content_embedding .shape [1 ]).long ()
1469
1380
1470
- return content_embedding , speaker_embedding , encoded_len , duration , aug_embeddings
1381
+ return content_embedding , speaker_embedding , encoded_len , duration
1471
1382
1472
1383
def get_pitch_contour (self , wav_text_id ):
1473
1384
pitch_contour_fn = f"pitch_contour_{ wav_text_id } .pt"
@@ -1531,14 +1442,6 @@ def pad_collate_fn(self, batch):
1531
1442
duration_padded = torch .nn .functional .pad (duration , (0 , max_encoded_len - duration .size (0 )), value = 0.0 )
1532
1443
durations_padded .append (duration_padded )
1533
1444
1534
- other_content_embedding_keys = [k for k in final_batch if k .startswith ("content_embedding_" )]
1535
- for key in other_content_embedding_keys :
1536
- other_content_embeddings_padded = []
1537
- for encoded in final_batch [key ]:
1538
- encoded_padded = torch .nn .functional .pad (encoded , (0 , max_encoded_len - encoded .size (1 )), value = 0 )
1539
- other_content_embeddings_padded .append (encoded_padded )
1540
- final_batch [key ] = other_content_embeddings_padded
1541
-
1542
1445
final_batch ["audio" ] = audios_padded
1543
1446
final_batch ["mel_spectrogram" ] = mels_padded
1544
1447
final_batch ["pitch_contour" ] = pitch_contours_padded
@@ -1548,7 +1451,7 @@ def pad_collate_fn(self, batch):
1548
1451
for key in final_batch :
1549
1452
final_batch [key ] = torch .stack (final_batch [key ])
1550
1453
1551
- return dict ( final_batch )
1454
+ return final_batch
1552
1455
1553
1456
def __getitem__ (self , index ):
1554
1457
sample = self .data [index ]
@@ -1563,9 +1466,7 @@ def __getitem__(self, index):
1563
1466
if self .pitch_conditioning :
1564
1467
pitch_contour = self .get_pitch_contour (rel_audio_path_as_text_id )
1565
1468
1566
- content_embedding , speaker_embedding , encoded_len , duration , aug_embeddings = self .get_ssl_features (
1567
- rel_audio_path_as_text_id
1568
- )
1469
+ content_embedding , speaker_embedding , encoded_len , duration = self .get_ssl_features (rel_audio_path_as_text_id )
1569
1470
1570
1471
if self .speaker_conditioning_type == "mean" :
1571
1472
assert sample ["speaker" ] in self .mean_speaker_embeddings , "{} not in speaker emb" .format (sample ['speaker' ])
@@ -1586,17 +1487,6 @@ def __getitem__(self, index):
1586
1487
mel_spectrogram = self .get_mel_spectrogram (rel_audio_path_as_text_id )
1587
1488
mel_len = torch .tensor (mel_spectrogram .shape [1 ]).long ()
1588
1489
1589
- alternate_speakers = [spk for spk in self .mean_speaker_embeddings if spk != sample ["speaker" ]]
1590
- if len (alternate_speakers ) == 0 :
1591
- alternate_speaker = sample ["speaker" ]
1592
- else :
1593
- if self .alternate_speaker_conditioning == "random" :
1594
- alternate_speaker = random .choice (alternate_speakers )
1595
- elif self .alternate_speaker_conditioning == "fixed" :
1596
- alternate_speaker = min (alternate_speakers )
1597
-
1598
- alternate_speaker_embedding = self .mean_speaker_embeddings [alternate_speaker ]
1599
-
1600
1490
if pitch_contour is not None :
1601
1491
if self .pitch_normalization in ["speaker_wise" , "global" ]:
1602
1492
mean , std = self .pitch_mean , self .pitch_std
@@ -1625,7 +1515,6 @@ def __getitem__(self, index):
1625
1515
'audio_len' : audio_length ,
1626
1516
'content_embedding' : content_embedding ,
1627
1517
'speaker_embedding' : speaker_embedding ,
1628
- 'alternate_speaker_embedding' : alternate_speaker_embedding ,
1629
1518
'encoded_len' : encoded_len ,
1630
1519
'pitch_contour' : pitch_contour ,
1631
1520
'speaker' : speaker ,
@@ -1635,9 +1524,6 @@ def __getitem__(self, index):
1635
1524
'duration' : duration ,
1636
1525
}
1637
1526
1638
- for aug_type in aug_embeddings :
1639
- item ["content_embedding_{}" .format (aug_type )] = aug_embeddings [aug_type ]
1640
-
1641
1527
return item
1642
1528
1643
1529
def __len__ (self ):
0 commit comments