Skip to content

Commit

Permalink
add torch optimized mean and median pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Apr 7, 2023
1 parent 8f4c14e commit 443aa38
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 3 deletions.
43 changes: 43 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import timeit

import torch

from torchcrepe.filter import mean, median, nanfilter, nanmean, nanmedian

###############################################################################
# Test filter.py
###############################################################################


def test_mean():
    """Check the optimized `mean` against the `nanfilter` + `nanmean` path."""

    def reference_mean(signal, win_length):
        # Previous implementation: apply nanmean window by window
        return nanfilter(signal, win_length, nanmean)

    signal = torch.rand(1, 44100)

    # Replace roughly 10% of the samples with nan
    dropped = torch.rand_like(signal) < 0.1
    signal[dropped] = float("nan")

    for win_length in (3, 9):
        assert torch.allclose(
            mean(signal, win_length),
            reference_mean(signal, win_length),
            equal_nan=True,
        )

    # To compare speed, time both callables with
    # timeit.timeit(lambda: ..., number=10) and print the ratio.


def test_median():
    """Check the optimized `median` against the `nanfilter` + `nanmedian` path."""

    def reference_median(signal, win_length):
        # Previous implementation: apply nanmedian window by window
        return nanfilter(signal, win_length, nanmedian)

    signal = torch.rand(1, 44100)

    # Replace roughly 10% of the samples with nan
    dropped = torch.rand_like(signal) < 0.1
    signal[dropped] = float("nan")

    for win_length in (3, 9):
        assert torch.allclose(
            median(signal, win_length),
            reference_median(signal, win_length),
            equal_nan=True,
        )

    # To compare speed, time both callables with
    # timeit.timeit(lambda: ..., number=10) and print the ratio.
75 changes: 72 additions & 3 deletions torchcrepe/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import torch

from torch.nn import functional as F

###############################################################################
# Sequence filters
def mean(signals, win_length=9):
    """Moving-average filter that ignores nan values

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter; nan marks a missing sample
        win_length
            The size of the analysis window. The window is centered on each
            frame, so an even size is widened to the next odd size and the
            output keeps the input length.

    Returns
        filtered (torch.tensor (shape=(batch, time)))
            Mean of the non-nan samples in each window; nan where a window
            contains no valid sample
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    # conv1d expects (batch, channels, time)
    signals = signals.unsqueeze(1)

    # Zero out the missing samples so they do not contribute to the sums
    valid = ~torch.isnan(signals)
    zero_filled = signals.masked_fill(~valid, 0.0)

    # Centered box kernel; match dtype/device so non-float32 input works
    half = win_length // 2
    box_kernel = torch.ones(
        1,
        1,
        2 * half + 1,
        dtype=signals.dtype,
        device=signals.device,
    )

    # Sum of valid samples in each window (zero padding adds nothing)
    window_sum = F.conv1d(zero_filled, box_kernel, stride=1, padding=half)

    # Number of valid samples in each window
    valid_count = F.conv1d(
        valid.to(signals.dtype),
        box_kernel,
        stride=1,
        padding=half,
    )

    # Counts are whole numbers; the 0.5 threshold tolerates convolution
    # backends that do not return exact integers
    empty = valid_count < 0.5

    # Clamp only to avoid dividing by zero; empty windows are overwritten below
    filtered = window_sum / valid_count.clamp(min=1)

    # Mark windows by their valid count, not by a zero result, so that a
    # genuine mean of 0 is preserved
    filtered = filtered.masked_fill(empty, float("nan"))

    return filtered.squeeze(1)


def median(signals, win_length):
    """Moving-median filter that ignores nan values

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter; nan marks a missing sample
        win_length
            The size of the analysis window. The window is centered on each
            frame, so an even size is widened to the next odd size and the
            output keeps the input length.

    Returns
        filtered (torch.tensor (shape=(batch, time)))
            Median of the non-nan samples in each window (the lower of the
            two middle values when the count is even); nan where a window
            contains no valid sample
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    half = win_length // 2

    # Pad with nan so samples beyond the edges are ignored exactly like
    # missing samples. Constant padding also works when the signal is
    # shorter than the padding, which reflect padding does not allow.
    padded = F.pad(signals, (half, half), mode="constant", value=float("nan"))

    # Sliding windows: (batch, time, window)
    windows = padded.unfold(1, 2 * half + 1, 1)

    valid = ~torch.isnan(windows)
    valid_count = valid.sum(dim=-1)

    # Push ignored samples to the end of each sorted window
    ordered, _ = torch.sort(windows.masked_fill(~valid, float("inf")), dim=-1)

    # Index of the (lower) median among the valid samples of each window
    median_idx = (valid_count - 1).clamp(min=0) // 2
    filtered = ordered.gather(-1, median_idx.unsqueeze(-1)).squeeze(-1)

    # Mark windows by their valid count, not by an infinite result, so that
    # a genuine infinite median is preserved
    return filtered.masked_fill(valid_count == 0, float("nan"))


###############################################################################
Expand Down

0 comments on commit 443aa38

Please sign in to comment.