Skip to content

Instantly share code, notes, and snippets.

@xvdp
Last active September 12, 2024 09:30
Show Gist options
  • Save xvdp/77d25acfb20a49e44b89f80f0fa2f7c2 to your computer and use it in GitHub Desktop.
Save xvdp/77d25acfb20a49e44b89f80f0fa2f7c2 to your computer and use it in GitHub Desktop.
Fourier 1d Conv is constant speed as kernel support increases
"""@xvdp
Fourier 1D convolutions run at constant speed as the kernel support grows. Hardware-optimized sliding-window convolutions are faster for small supports (<20).
Vec (conv) kernel = ifft(fft(Vec) * fft(kernel))
I used these to apply RIRs (Room Impulse Responses) for audio augmentation.
I filed an issue with PyTorch, https://github.com/pytorch/pytorch/issues/79222, and just noticed I had not gisted it.
"""
from typing import Optional
import time
import torch
from torch import Tensor
import torch.nn.functional as F
# pylint: disable=no-member
# pylint: disable=suppressed-message
def fftconv1d(x: Tensor, weight: Tensor,
              bias: Optional[Tensor] = None,
              padding: int = 0,
              groups: int = 1) -> Tensor:
    """FFT-based counterpart of ``F.conv1d`` with stride and dilation fixed at 1.

    The signal is zero-padded by ``padding`` on both sides, the kernel is
    zero-extended to the same length, and the two are multiplied in the
    frequency domain. The kernel spectrum is conjugated so the result is the
    cross-correlation that ``F.conv1d`` computes, not a flipped convolution.

    Args
        x: Tensor (batch_size, in_channels, size)
        weight: Tensor (out_channels, in_channels//groups, kernel_size)
        bias: Tensor [None] out_channels
        padding int [0] zeros added to both ends of the last dimension
        groups int [1] in_channels, out_channels must be divisible by groups

    Returns
        Tensor (batch_size, out_channels, size + 2*padding - kernel_size + 1)

    Raises
        AssertionError on any shape mismatch, or when kernel_size exceeds the
        padded signal length.

    adapted from https://towardsdatascience.com/fourier-convolutions-in-pytorch-4cbd23c70005
    """
    assert x.ndim == 3, "x expected shape: (N, C, L)"
    assert weight.ndim == 3, "weight expected (out_channels, in_channels//groups, kernel)"
    _out, _in, _ksize = weight.shape
    if bias is not None:
        assert bias.ndim == 1 and len(bias) == _out, "bias vector sized as out_channels reqd"
    assert not x.shape[1] % groups, f"in_channels must be mod groups {x.shape[1], groups}"
    assert not _out % groups, f"out_channels must be mod groups {_out, groups}"
    assert x.shape[1] == groups*_in, (
        f"Given groups={groups} and weight {tuple(weight.shape)}, "
        f"expected input {tuple(x.shape)} to have {groups*_in} channels")

    out = F.pad(x, [padding, padding])
    _size = out.shape[-1]
    _pad = _size - _ksize
    # A negative value here used to crop the kernel silently and return garbage.
    assert _pad >= 0, f"kernel size {_ksize} can't be greater than padded input size {_size}"

    x_rfft = torch.fft.rfft(out, n=_size, dim=-1)
    # n=_size zero-extends the kernel; conj() turns the spectral product into
    # a cross-correlation without mutating the FFT output in place.
    w_rfft = torch.fft.rfft(weight, n=_size, dim=-1).conj()

    # Expose the group axis so a single einsum handles every value of `groups`:
    # x: (N, G, in/G, F), w: (G, out/G, in/G, F) -> (N, G, out/G, F)
    _batch, _freq = x_rfft.shape[0], x_rfft.shape[-1]
    x_rfft = x_rfft.reshape(_batch, groups, _in, _freq)
    w_rfft = w_rfft.reshape(groups, _out//groups, _in, _freq)
    y_rfft = torch.einsum("ngif,goif->ngof", x_rfft, w_rfft).reshape(_batch, _out, _freq)

    # The output length must be passed explicitly: the inverse real FFT
    # otherwise assumes an even length, which corrupts odd-length signals.
    # Only the first _pad + 1 samples are free of circular wrap-around.
    out = torch.fft.irfft(y_rfft, n=_size, dim=-1)[..., :_pad + 1].contiguous()
    if bias is not None:
        out = out + bias.view(1, -1, 1)
    return out
def _testconv(cuda=True, grad=True, pad=None, out_channels=4, in_channels=2,
              batch_size=20, size=4096, ksize=1000, groups=1):
    """Time ``F.conv1d`` against ``fftconv1d`` on random data and check they agree.

    Args
        cuda: bool [True] run on the "cuda" device instead of CPU
        grad: bool [True] mark the input signal as requiring grad
        pad: int [None] padding passed to both convolutions, None -> ksize//2
        out_channels, in_channels, batch_size, size, ksize, groups:
            shapes of the random signal (batch_size, in_channels, size) and
            kernel (out_channels, in_channels//groups, ksize)

    Prints both wall-clock times in milliseconds and raises AssertionError if
    the outputs differ beyond rtol=atol=1e-3.
    NOTE(review): each op is timed once with no warm-up, so the numbers are
    indicative only.
    """
    if pad is None:
        pad = ksize//2
    signal = torch.randn(batch_size, in_channels, size)
    if grad:
        signal.requires_grad = True
    kernel = torch.randn(out_channels, in_channels//groups, ksize)
    bias = torch.randn(out_channels)
    print(f"\n signal: {tuple(signal.shape)}, kernel: {tuple(kernel.shape)}")
    if cuda:
        signal = signal.to(device="cuda")
        kernel = kernel.to(device="cuda")
        bias = bias.to(device="cuda")
        # make sure pending transfers are not billed to the first timed op
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time()
    _start = time.perf_counter()
    y0 = F.conv1d(signal, kernel, bias=bias, padding=pad, groups=groups)
    if cuda:
        torch.cuda.synchronize()
    _fconv = time.perf_counter()
    y2 = fftconv1d(signal, kernel, bias=bias, padding=pad, groups=groups)
    if cuda:
        torch.cuda.synchronize()
    _fftconv = time.perf_counter()

    _test = f'test: cuda:{cuda}, grad:{grad}, pad{pad}, out:{out_channels}, in{in_channels}, groups{groups}'
    print(_test)
    # elapsed times, already converted to milliseconds
    _nntime = 1000*(_fconv - _start)
    _fftime = 1000*(_fftconv - _fconv)

    # default to no marker so an exact tie cannot leave these names unbound
    _nn = _ff = ""
    if _nntime < _fftime:
        _nn = "\t\t\tnn.Conv1d is faster"
    elif _fftime < _nntime:
        _ff = "\t\t\tFFT faster"
    print(f" nn.Conv1d() time {_nntime:.1f} ms {_nn}")
    print(f" fftconv1d time {_fftime:.1f} ms {_ff}")
    assert torch.allclose(y0, y2, rtol=1e-3, atol=1e-3), _test
def test_conv_opt():
    """Run ``_testconv`` over a grid of padding, grad, device, group, channel and size options.

    The CUDA variants are included only when a CUDA device is available, so
    the sweep also runs on CPU-only machines.
    """
    from itertools import product  # local import keeps this block self-contained

    cuda = [True, False] if torch.cuda.is_available() else [False]
    grad = [True, False]
    padding = [0, None, 100]
    groups = [1, 2]
    out_channels = [4, 2]
    in_channels = [2, 8]
    batch_size = 20
    size = [4096, 14400]
    ksize = [9, 1000]
    # product() varies the rightmost option fastest, matching the order of
    # the nested loops it replaces.
    for p, r, c, g, i, o, k, s in product(padding, grad, cuda, groups,
                                          in_channels, out_channels, ksize, size):
        _testconv(cuda=c, grad=r, pad=p, out_channels=o, groups=g,
                  in_channels=i, batch_size=batch_size, size=s, ksize=k)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment