Skip to content

Commit

Permalink
Support exporting ONNX
Browse files Browse the repository at this point in the history
  • Loading branch information
yqzhishen committed Jul 10, 2023
1 parent 86bbfbe commit de0ab29
Showing 1 changed file with 225 additions and 0 deletions.
225 changes: 225 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import argparse
from typing import Dict, Tuple, Union

import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel

from src import E2E0
from src.constants import *

import torch


class MelSpectrogram_ONNX(nn.Module):
def __init__(
self,
n_mel_channels,
sampling_rate,
win_length,
hop_length,
n_fft=None,
mel_fmin=0,
mel_fmax=None,
clamp=1e-5
):
super().__init__()
n_fft = win_length if n_fft is None else n_fft
mel_basis = mel(
sr=sampling_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=mel_fmin,
fmax=mel_fmax,
htk=True)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
self.n_fft = win_length if n_fft is None else n_fft
self.hop_length = hop_length
self.win_length = win_length
self.sampling_rate = sampling_rate
self.n_mel_channels = n_mel_channels
self.clamp = clamp

def forward(self, audio, center=True):
fft = torch.stft(
audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=torch.hann_window(self.win_length, device=audio.device),
center=center,
return_complex=False
)
magnitude = torch.sqrt(torch.sum(fft ** 2, dim=-1))
mel_output = torch.matmul(self.mel_basis, magnitude)
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
return log_mel_spec


class RMVPE_ONNX(nn.Module):
def __init__(self, hop_length):
super().__init__()
self.model = E2E0(4, 1, (2, 2))
self.mel_extractor = MelSpectrogram_ONNX(
N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX
)

def mel2hidden(self, mel):
n_frames = mel.shape[-1]
mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')
hidden = self.model(mel) # [B, T, N]
return hidden[:, :n_frames]

# noinspection PyMethodMayBeStatic
def decode(self, hidden, threshold=0.03):
idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N]
idx_cents = idx * 20 + CONST # [B=1, N]
center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
start = torch.clip(center - 4, min=0) # [B, T, 1]
end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1]
idx_mask = (idx >= start) & (idx < end) # [B, T, N]
weights = hidden * idx_mask # [B, T, N]
product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T]
weight_sum = torch.sum(weights, dim=2) # [B, T]
cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
f0 = 10 * 2 ** (cents / 1200)
uv = hidden.max(dim=2)[0] < threshold # [B, T]
return f0 * ~uv, uv

def forward(self, waveform, threshold):
mel = self.mel_extractor(waveform, center=True)
hidden = self.mel2hidden(mel)
f0, uv = self.decode(hidden, threshold=threshold)
return f0, uv


def parse_args(args=None, namespace=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
help="path to the model checkpoint",
)
parser.add_argument(
"-o",
"--output",
type=str,
required=True,
help="path to the output onnx file",
)
parser.add_argument(
"-hop",
"--hop_length",
type=str,
required=False,
default=160,
help="hop_length under 16khz sampling rate | default: 160",
)
parser.add_argument(
"--optimize",
action="store_true",
help="whether to optimize the generated ONNX graph"
)
return parser.parse_args(args=args, namespace=namespace)


def onnx_override_io_shapes(
model, # ModelProto
input_shapes: Dict[str, Tuple[Union[str, int]]] = None,
output_shapes: Dict[str, Tuple[Union[str, int]]] = None,
):
"""
Override the shapes of inputs/outputs of the model graph (in-place operation).
:param model: model to perform the operation on
:param input_shapes: a dict with keys as input/output names and values as shape tuples
:param output_shapes: the same as input_shapes
"""
def _override_shapes(
shape_list_old, # RepeatedCompositeFieldContainer[ValueInfoProto]
shape_dict_new: Dict[str, Tuple[Union[str, int]]]):
for value_info in shape_list_old:
if value_info.name in shape_dict_new:
name = value_info.name
dims = value_info.type.tensor_type.shape.dim
assert len(shape_dict_new[name]) == len(dims), \
f'Number of given and existing dimensions mismatch: {name}'
for i, dim in enumerate(shape_dict_new[name]):
if isinstance(dim, int):
dims[i].dim_param = ''
dims[i].dim_value = dim
else:
dims[i].dim_value = 0
dims[i].dim_param = dim

if input_shapes is not None:
_override_shapes(model.graph.input, input_shapes)
if output_shapes is not None:
_override_shapes(model.graph.output, output_shapes)


def export():
cmd = parse_args()

model_path = cmd.model
output_path = cmd.output

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('loading model and audio')
rmvpe = RMVPE_ONNX(hop_length=cmd.hop_length)
rmvpe.model.load_state_dict(torch.load(model_path)['model'])
rmvpe.eval().to(device)
waveform = torch.randn(1, 114514, dtype=torch.float32, device=device).clip(min=-1., max=1.)
threshold = torch.tensor(0.03, dtype=torch.float32, device=device)
print('start exporting ...')
with torch.no_grad():
torch.onnx.export(
rmvpe,
(
waveform,
threshold,
),
output_path,
input_names=[
'waveform',
'threshold'
],
output_names=[
'f0',
'uv'
],
dynamic_axes={
'waveform': {
1: 'n_samples'
},
'f0': {
1: 'n_frames'
},
'uv': {
1: 'n_frames'
}
},
opset_version=17
)
if cmd.optimize:
import onnx
import onnxsim
print('start optimizing ...')
model = onnx.load(output_path)
onnx_override_io_shapes(model, output_shapes={
'f0': (1, 'n_frames'),
'uv': (1, 'n_frames'),
})
model, check = onnxsim.simplify(
model,
include_subgraph=True
)
assert check, 'Simplified ONNX model could not be validated'
onnx.save(model, output_path)


if __name__ == '__main__':
export()

0 comments on commit de0ab29

Please sign in to comment.