-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy patheval.py
executable file
·95 lines (80 loc) · 3.49 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
import os
import argparse
import torch
import librosa
import time
from scipy.io.wavfile import write
from tqdm import tqdm
import utils
from contentvec import HubertModelWithFinalProj
from models import SynthesizerTrn
from mel_processing import mel_spectrogram_torch
import logging
logging.getLogger('numba').setLevel(logging.WARNING)
import torch.autograd.profiler as profiler
## check device
if torch.cuda.is_available() is True:
device = "cuda:0"
else:
device = "cpu"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hpfile", type=str, default="logs/quickvc/config.json", help="path to json config file")
parser.add_argument("--ptfile", type=str, default="logs/quickvc/quickvc.pth", help="path to pth file")
parser.add_argument("--txtpath", type=str, default="dataset/eval.txt", help="path to txt file")
parser.add_argument("--outdir", type=str, default="eval", help="path to output dir")
parser.add_argument("--use_timestamp", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(args.outdir, exist_ok=True)
hps = utils.get_hparams_from_file(args.hpfile)
print("Loading model...")
net_g = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model).to(device)
_ = net_g.eval()
total = sum([param.nelement() for param in net_g.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))
print("Loading checkpoint...")
_ = utils.load_checkpoint(args.ptfile, net_g, None)
print(f"Loading contentvec checkpoint")
contentvec_extractor = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best").to(device).eval()
print("Processing text...")
allfiles, titles, srcs, tgts = [], [], [], []
allfiles = open(args.txtpath, "r").readlines()
for i in range(len(allfiles) // 2):
tgt = allfiles[i].strip()
src = allfiles[-i].strip()
title = src.split("/")[-1][:-4] + "_" + tgt.split("/")[-1][:-4]
titles.append(title)
srcs.append(src)
tgts.append(tgt)
print("Synthesizing...")
with torch.no_grad():
for line in tqdm(zip(titles, srcs, tgts)):
title, src, tgt = line
# tgt
wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
mel_tgt = mel_spectrogram_torch(
wav_tgt,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
)
wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
c = contentvec_extractor.extract(wav_src).transpose(2,1)
audio = net_g.infer(c, mel=mel_tgt)
audio = audio[0][0].data.cpu().float().numpy()
if args.use_timestamp:
timestamp = time.strftime("%m-%d_%H-%M", time.localtime())
write(os.path.join(args.outdir, "{}.wav".format(timestamp+"_"+title)), hps.data.sampling_rate, audio)
else:
write(os.path.join(args.outdir, f"{title}.wav"), hps.data.sampling_rate, audio)