Skip to content

Commit

Permalink
Debug BF16 and RMVPE
Browse files Browse the repository at this point in the history
  • Loading branch information
ylzz1997 committed Jul 11, 2023
1 parent 3678712 commit f808d8e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if writers is not None:
writer, writer_eval = writers

half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16
half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16

# train_loader.batch_sampler.set_epoch(epoch)
global global_step
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
elif f0_predictor == "rmvpe":
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"])
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
else:
raise Exception("Unknown f0 predictor")
return f0_predictor_object
Expand Down

0 comments on commit f808d8e

Please sign in to comment.