Update Documentation for Audio Processing Functions #2287

Open · wants to merge 2 commits into base: main
Changes from all commits
128 changes: 86 additions & 42 deletions whisper/audio.py
@@ -20,39 +20,48 @@
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
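
For reference, the timing constants above follow from base constants defined earlier in whisper/audio.py that this diff does not show. A minimal sketch, assuming SAMPLE_RATE = 16000, HOP_LENGTH = 160, and CHUNK_LENGTH = 30:

# Minimal sketch of the derived constants; the base values below are assumptions
# taken from earlier in whisper/audio.py and are not part of this diff.
SAMPLE_RATE = 16000
HOP_LENGTH = 160
CHUNK_LENGTH = 30

N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE                   # 480000 samples in a 30-second chunk
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2                     # 320 samples, i.e. 20 ms per audio token
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH            # 100 frames per second, i.e. 10 ms per frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN   # 50 tokens per second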


def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Loads an audio file as a mono waveform, resampling to the specified sample rate.

Parameters
----------
file: str
The audio file to open

sr: int
The sample rate to resample the audio if necessary
file : str
Path to the audio file.
sr : int, optional
Target sample rate for resampling, defaults to SAMPLE_RATE.

Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
np.ndarray
1D NumPy array of the audio waveform in float32, normalized between -1 and 1.

Raises
------
RuntimeError
If the audio cannot be loaded.

Notes
-----
Requires ffmpeg installed and accessible in the system's PATH.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off

cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
"ffmpeg", # Command to run the ffmpeg tool.
"-nostdin", # Prevents ffmpeg from reading from stdin.
"-threads", "0", # Uses all available CPU cores for processing.
"-i", file, # Specifies the input file path.
"-f", "s16le", # Sets the output format to 16-bit PCM.
"-ac", "1", # Converts audio to mono (1 channel).
"-acodec", "pcm_s16le", # Specifies the audio codec as PCM signed 16-bit little-endian.
"-ar", str(sr), # Resamples the audio to the specified sample rate.
"-" # Outputs the processed audio to stdout.
]

# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
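
As a quick usage sketch (not part of this diff): the snippet below assumes the whisper package is importable and the ffmpeg CLI is on PATH; "audio.wav" is a hypothetical example path.

import numpy as np
from whisper.audio import SAMPLE_RATE, load_audio

# Roughly the same decode as the command list above, run via subprocess:
#   ffmpeg -nostdin -threads 0 -i audio.wav -f s16le -ac 1 -acodec pcm_s16le -ar 16000 -
waveform = load_audio("audio.wav")        # "audio.wav" is a hypothetical path
assert waveform.dtype == np.float32       # int16 samples rescaled to roughly [-1.0, 1.0)
duration_seconds = len(waveform) / SAMPLE_RATE
print(f"{duration_seconds:.2f} s, {waveform.shape[0]} samples")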
@@ -63,8 +72,22 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):

"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
Pads or trims the input array to a specified length along the given axis.

Parameters
----------
array : torch.Tensor or np.ndarray
Input array.
length : int, optional
Target length along the specified axis, defaults to N_SAMPLES.
axis : int, optional
Axis to pad or trim, defaults to -1 (the last axis).

Returns
-------
torch.Tensor or np.ndarray
The modified array, either padded with zeros or trimmed to the target length.

Notes
-----
Handles both PyTorch tensors and NumPy arrays, applying the appropriate
padding or trimming method depending on the array type.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
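
A small sketch of pad_or_trim on NumPy input (not part of this diff), assuming whisper is importable; N_SAMPLES is the 30-second chunk length exported by whisper.audio.

import numpy as np
from whisper.audio import N_SAMPLES, pad_or_trim

short = np.zeros(16000, dtype=np.float32)                 # 1 second at 16 kHz
too_long = np.zeros(N_SAMPLES + 16000, dtype=np.float32)  # 31 seconds at 16 kHz

padded = pad_or_trim(short)      # zero-padded on the right up to N_SAMPLES
trimmed = pad_or_trim(too_long)  # truncated to the first N_SAMPLES samples
assert padded.shape == trimmed.shape == (N_SAMPLES,)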
@@ -91,14 +114,30 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:

np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
Loads a precomputed Mel filterbank matrix for converting STFT to a Mel spectrogram.

Parameters
----------
device : torch.device
The device (CPU or GPU) to load the tensor onto.
n_mels : int
The number of Mel bands, must be either 80 or 128.

Returns
-------
torch.Tensor
A tensor containing the Mel filterbank matrix.

Raises
------
AssertionError
If `n_mels` is not supported.

Notes
-----
The Mel filterbank matrices are saved in a compressed npz file, which decouples
the dependency on librosa for generating these filters.

"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
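
For illustration (a sketch, not part of this diff), the filterbank can be loaded on CPU and its shape inspected; the regeneration comment restates the librosa snippet from the removed docstring above.

import torch
from whisper.audio import mel_filters

filters = mel_filters(torch.device("cpu"), n_mels=80)
print(filters.shape)  # expected: torch.Size([80, 201]), i.e. n_mels x (N_FFT // 2 + 1) with N_FFT = 400
# Per the removed docstring, the stored matrices were produced with
# librosa.filters.mel(sr=16000, n_fft=400, n_mels=80) (and n_mels=128),
# then saved with np.savez_compressed to "mel_filters.npz".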

@@ -114,27 +153,32 @@ def log_mel_spectrogram(
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
Computes the log-Mel spectrogram of an audio waveform.

Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

n_mels: int
The number of Mel-frequency filters, only 80 is supported

padding: int
Number of zero samples to pad to the right

device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
audio : Union[str, np.ndarray, torch.Tensor]
The audio input, either as a file path, NumPy array, or Torch tensor.
The waveform should be sampled at 16 kHz.
n_mels : int, optional
The number of Mel-frequency filters, either 80 or 128. Defaults to 80.
padding : int, optional
Number of zero samples to pad at the end of the audio. Defaults to 0.
device : Optional[Union[str, torch.device]], optional
The device to perform computations on. If provided, the audio tensor is moved
to this device. Defaults to None.

Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
torch.Tensor
A tensor containing the Mel spectrogram with shape (n_mels, n_frames).

Notes
-----
The function expects a 16 kHz sampling rate for the input audio and uses a Hann
window for the Short-Time Fourier Transform (STFT).
"""

if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
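
Finally, a hedged end-to-end sketch tying the functions in this file together (not part of this diff); "audio.wav" is again a hypothetical path, and load_audio requires ffmpeg on PATH.

from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim

waveform = load_audio("audio.wav")   # mono float32 waveform at 16 kHz
waveform = pad_or_trim(waveform)     # exactly 30 seconds of samples
mel = log_mel_spectrogram(waveform)  # torch.Tensor; expected shape (80, 3000) with the default n_mels=80
print(mel.shape)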