Skip to content

Commit

Permalink
网络参数调试, 数据处理修正 (Network parameter tuning; data-processing fixes)
Browse files Browse the repository at this point in the history
  • Loading branch information
see2023 committed Jan 5, 2024
1 parent ffe88f8 commit 3d4eb9b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 21 deletions.
2 changes: 1 addition & 1 deletion configs/config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"train": {
"log_interval": 20,
"eval_interval": 100,
"eval_interval": 1,
"seed": 42,
"epochs": 1000,
"learning_rate": 0.0002,
Expand Down
45 changes: 34 additions & 11 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from utils import load_wav_to_torch, load_filepaths_and_text
from text import cleaned_text_to_sequence
from config import config
import motion.const_map as const_map

"""Multi speaker version"""

Expand Down Expand Up @@ -422,8 +423,8 @@ def __init__(self, audio_visemes_list_file, hparams):
print('audio_visemes_list_items: ', len(self.audio_visemes_list_items))
random.seed(1234)
random.shuffle(self.audio_visemes_list_items)
self.max_visemes_len = 1200
self.min_visemes_len = 1190
self.max_visemes_len = 1210
self.min_visemes_len = 1180
self._filter()


Expand All @@ -449,23 +450,45 @@ def _filter(self):
def __getitem__(self, index):
    """Load one (audio latent, visemes) training pair.

    Returns:
        audio_z: torch.Tensor of shape [192, z_len] — precomputed audio
            latents saved with torch.save (loaded at const_map.Z_FPS).
        visemes: torch.Tensor of shape [seq_len, 61] — ARKit blendshape
            frames (stored at const_map.ARKIT_FPS), time-aligned to audio_z.
    """
    # NOTE(review): this text was reconstructed from a rendered diff that
    # interleaved old and new lines; the old (pre-commit) lines were dropped.
    audio_file, visemes_file = self.audio_visemes_list_items[index]
    audio_z = torch.load(audio_file).squeeze(0).detach()
    # audio_z: [192, seq_len(~1722)]

    visemes = np.load(visemes_file)
    visemes = torch.from_numpy(visemes)
    # stored as [seq_len(~1194), 61]; move time to dim 1 -> [61, seq_len]
    visemes = visemes.transpose(0, 1)
    if visemes.shape[1] > self.max_visemes_len:
        # cut the extra frames beyond max_visemes_len
        visemes = visemes[:, :self.max_visemes_len]
    elif visemes.shape[1] < self.max_visemes_len:
        # padding with the last frame was intentionally disabled in this
        # commit; _filter() already bounds lengths to [min, max]_visemes_len
        pass

    # delay the visemes by visemes_offset seconds relative to the audio
    visemes_offset = 0.02
    visemes_offset_frames = int(visemes_offset * const_map.ARKIT_FPS)
    visemes = visemes[:, visemes_offset_frames:]

    # audio offset kept at 0 for now; same mechanism as above
    audio_z_offset = 0.0
    audio_z_offset_frames = int(audio_z_offset * const_map.Z_FPS)
    audio_z = audio_z[:, audio_z_offset_frames:]

    # compute both durations in seconds and drop the surplus of the longer
    # stream so audio and visemes cover the same time span
    visemes_duration = visemes.shape[1] / const_map.ARKIT_FPS
    audio_z_duration = audio_z.shape[1] / const_map.Z_FPS
    if visemes_duration > audio_z_duration:
        visemes = visemes[:, :int(audio_z_duration * const_map.ARKIT_FPS)]
    elif visemes_duration < audio_z_duration:
        audio_z = audio_z[:, :int(visemes_duration * const_map.Z_FPS)]

    # back to [seq_len, 61] for the consumer
    visemes = visemes.transpose(0, 1)
    return audio_z, visemes

def __len__(self):
    """Number of (audio, visemes) file pairs in the dataset."""
    items = self.audio_visemes_list_items
    return len(items)
Expand Down
16 changes: 10 additions & 6 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,21 +1082,25 @@ def get_post_enc_dec(self):

class VisemesNet(nn.Module):
def active(self, x):
    """Apply the configured activation to x.

    self.active_fun selects: 1 -> tanh, 2 -> relu, 3 -> LeakyReLU
    (module created in __init__); any other value -> identity.
    """
    if self.active_fun == 3:
        # LeakyReLU lives on the module, created only when active_fun == 3
        return self.leakyReLU(x)
    fn = {1: torch.tanh, 2: torch.relu}.get(self.active_fun)
    return fn(x) if fn is not None else x

def __init__(self, hidden_channels, lstm_bidirectional=True, active_fun = 2, enable_conv=True,
def __init__(self, hidden_channels, lstm_bidirectional=True, active_fun = 3, enable_conv=True,
use_transformer = False, enable_dropout=True):
super(VisemesNet, self).__init__()
self.lstm_bidirectional = lstm_bidirectional
self.lstm_directions = 2 if lstm_bidirectional else 1
self.use_transformer = use_transformer
self.enable_dropout = enable_dropout
if active_fun == 3:
self.leakyReLU = nn.LeakyReLU(negative_slope=0.01)
if use_transformer:
num_heads=8
num_layers=3
Expand Down Expand Up @@ -1140,20 +1144,20 @@ def forward_transformer(self, x, y=None):
# x [batch_size, hidden_channels, seq_len]
if self.enable_conv:
x = self.conv1d_pre(x)
x = x.permute(2, 0, 1) # Transformer encoder expects [seq_len, batch_size, features]
# batch_first: True (batch, seq, feature); False (seq, batch, feature).
x = x.transpose(1, 2)

expressions = self.transformer_encoder(x)

expressions = expressions.permute(1, 0, 2) # [batch_size, features, seq_len]
if self.enable_dropout:
expressions = self.dropout(expressions)
expressions = self.fc1(expressions)
expressions = self.active(expressions)
# expressions = self.active(expressions)
if self.enable_dropout:
expressions = self.dropout(expressions)
expressions = self.fc2(expressions)

expressions = expressions.transpose(1, 2) # [batch_size, seq_len, features]
expressions = expressions.transpose(1, 2)
if self.enable_conv:
expressions = self.conv1d_post(expressions)

Expand Down
5 changes: 4 additions & 1 deletion motion/const_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@

ARKIT_COUNT = 61
VALID_ARKIT_COUNT = 52
ARKIT_FPS = 60
Z_FPS = 86.1328125

# Indices (1-based ordinal minus 1) whose valid values lie in [0, 1]: expressions whose names contain no left/right/up/down
g_positive_index = [14, 17, 18, 19, 20, 46, 51]
g_max_value_groups = [
Expand All @@ -94,7 +97,7 @@
]


def map_arkit_values(bs_weight_arkit, mirror=True):
def map_arkit_values(bs_weight_arkit, mirror=False):
# input: n * 116 float array
weights = np.zeros((bs_weight_arkit.shape[0], ARKIT_COUNT))
for r in range(bs_weight_arkit.shape[0]):
Expand Down
4 changes: 3 additions & 1 deletion motion/tts2ue.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def play_and_send(udp_sender, bs_npy_file, wav_file, fps):
exit(1)
bs_npy_file = sys.argv[1]
wav_file = sys.argv[2]
fps = sys.argv[3]
fps = 86.1328125
if sys.argv.__len__() > 3:
fps = sys.argv[3]
# fps to float
fps = float(fps)
if not os.path.exists(bs_npy_file):
Expand Down
25 changes: 24 additions & 1 deletion train_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import gc

logging.getLogger("numba").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
import commons
import utils
from data_utils import (
Expand Down Expand Up @@ -128,7 +129,9 @@ def eval_visemes_only(epoch, hps, net_v, eval_loader):
# print('visemes_hat_mse', visemes_hat_mse)
break
visemes_hat_mse_avg = visemes_hat_mse_sum / (batch_idx + 1)
print('------------------ eval visemes_hat_mse_avg: ', visemes_hat_mse_avg)
log_str = '------------------ eval epoch: {} visemes_hat_mse_avg: {:.6f}'.format(epoch, visemes_hat_mse_avg)
print(log_str)
logger.warning(log_str)
net_v.train()


Expand Down Expand Up @@ -189,6 +192,7 @@ def run():
os.makedirs(model_dir)
hps = utils.get_hparams_from_file(args.config)
hps.model_dir = model_dir
set_logger(hps)
if args.visemes:
run_only_visemes(hps)
# 比较路径是否相同
Expand Down Expand Up @@ -921,5 +925,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
generator.train()


def set_logger(hps):
    """Attach a timestamped file handler to the module-level logger.

    Creates ``train.<YYYYmmdd-HHMMSS>.log`` under ``hps.model_dir`` and
    routes INFO-and-above records there.  stdout output is handled
    separately via print() at the call sites.
    """
    log_format = logging.Formatter(
        "%(asctime)s %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    # one log file per run: train.<datetime>.log
    log_path = os.path.join(
        hps.model_dir,
        "train.{}.log".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
    )
    fh = logging.FileHandler(log_path, mode="a")
    fh.setLevel(logging.INFO)
    fh.setFormatter(log_format)
    # Without lowering the logger's own level, it inherits the root
    # WARNING level and INFO records never reach the file handler
    # (which is why eval logging previously had to use logger.warning).
    logger.setLevel(logging.INFO)
    logger.addHandler(fh)



if __name__ == "__main__":

run()

0 comments on commit 3d4eb9b

Please sign in to comment.