前言
在一些网络中,需要实时将时域信号转换为频域信号,快速傅里叶变换(Fast Fourier Transform,FFT)是一种常用方法。
本文主要介绍如何用pytorch实现FFT,并封装成一个自定义层,方便在一些需要时频变换的网络中即插即用。
一、构造信号样例
首先构造一个多通道信号,其形状为:(channels,time_samples),然后转换为tensor格式。代码如下:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# 设置信号参数
num_channels = 5 # 信号通道数
num_samples = 600 # 时间样本点
sampling_rate = 200 # 采样率
low_freq = 0 # 选择想要的频率范围下限
high_freq = 80 # # 选择想要的频率范围上限【注:这里并不是滤波】
# 构造一个多通道信号
t = np.linspace(start=0, stop=(num_samples-1)/sampling_rate, num=num_samples) # 时间轴
signals = 3*np.sin(2*np.pi*3*t)+2*np.sin(2*np.pi*15*t)+np.sin(2*np.pi*27*t) # 包含三个分量,幅值分别为3,2,1,频率分别为3,15,27
noise_level = 0.5
noise = np.random.normal(0, noise_level, signals.shape)
signals = signals + noise # 加入噪声
signals = np.tile(signals, (num_channels, 1)) # 复制成多通道信号
print('signals shape:', signals.shape) # (channels,time_samples)
# 转换为tensor格式信号
tensor_signals = torch.tensor(signals, dtype=torch.float32)
二、使用torch.fft模块
torch.fft.fft
和 torch.fft.rfft
都是 PyTorch 库中用于进行快速傅里叶变换(FFT)的函数,但它们适用于不同类型的输入数据和有一些关键的区别:
torch.fft.fft
:- 处理实数或复数输入,并输出复数结果,包含正频和负频部分
- 当输入是实数时,输出是对称的,因此含有冗余信息
- 当输入是复数时,输出是非对称的,完整表示了所有频率的复数幅度和相位信息
torch.fft.rfft
:- 专门用于对实数序列进行傅里叶变换,只处理实数输入
- 输出非负频率的复数结果,从而减少计算复杂度
代码如下:
# 使用PyTorch进行FFT
fft_signals = torch.fft.fft(tensor_signals, dim=1)
rfft_signals = torch.fft.rfft(tensor_signals, dim=1)
print('fft_signals shape:', fft_signals.shape)
print('rfft_signals shape:', rfft_signals.shape)
# 获取频率列表
freqs = torch.fft.fftfreq(num_samples, 1/sampling_rate) # 得到全部频率
rfreqs = torch.fft.rfftfreq(num_samples, 1/sampling_rate) # 得到非负频率
print('freqs length:', len(freqs))
print('freqs list:\n', freqs)
print('rfreqs length:', len(rfreqs))
print('rfreqs list:\n', rfreqs)
# 后续步骤使用freqs
# 选择想要的频率范围
mask = (freqs >= low_freq) & (freqs <= high_freq)
# 选择一个通道的信号进行可视化
channel_idx = 0
selected_channel_signal = signals[channel_idx] # 时域信号
selected_channel_fft = fft_signals[channel_idx][mask] # 选择频率范围内的频域信号
# 绘制原始信号
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(t, selected_channel_signal)
plt.title(f'Original Signal of Channel {channel_idx+1}')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
# 绘制频域信号
plt.subplot(2, 1, 2)
plt.plot(freqs[mask], 2.0/num_samples * torch.abs(selected_channel_fft)) # 2.0/num_samples因子用于标准化
plt.title(f'FFT of Channel {channel_idx+1} ({low_freq}-{high_freq} Hz)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude')
plt.tight_layout()
plt.show()
这里选择第一个通道,绘图结果如下:
三、构造FFT_Layer
为了便于即插即用,封装成一个自定义层,其中low_freq
和high_freq
用于获取需要的频率范围,方便某些情况下选择期望的频率特征。
代码如下:
# 自定义FFT层
class FFTLayer(nn.Module):
"""
in: (batch,channel,time)
out: (bath,channel,freq_len), (freq_len,)
"""
def __init__(self, sampling_rate, low_freq=None, high_freq=None):
super().__init__()
self.low_freq = low_freq
self.high_freq = high_freq
self.sampling_rate = sampling_rate
def forward(self, x):
num_samples = x.shape[2]
x = torch.fft.fft(x, dim=2)
factor = 2.0
freqs = torch.fft.fftfreq(num_samples, 1 / self.sampling_rate)
if (self.low_freq is None) and (self.high_freq is None):
fft = torch.abs(x)
else:
if (self.low_freq is None) and (self.high_freq is not None):
mask = freqs <= self.high_freq
elif (self.low_freq is not None) and (self.high_freq is None):
mask = freqs >= self.low_freq
else:
mask = (freqs >= self.low_freq) & (freqs <= self.high_freq)
fft = torch.abs(x[..., mask])
freqs = freqs[mask]
if self.low_freq is None or self.low_freq < 0: # 如果包含了负频率,标准化因子就不补偿2,而是1
factor = 1.0
print(factor)
return (factor / num_samples) * fft, freqs
使用该自定义层进行FFT,代码如下:
sig = tensor_signals.unsqueeze(0) # 将(channels,time_samples)转换为(1,channels,time_samples),其中,1代表batch
fft, freq = FFTLayer(sampling_rate=sampling_rate, low_freq=low_freq, high_freq=high_freq)(sig)
print(fft.shape)
print(freq.shape)
# 绘制原始信号
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(t, sig[0, channel_idx, :])
plt.title(f'Original Signal of Channel {channel_idx+1}')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
# 绘制频域信号
plt.subplot(2, 1, 2)
plt.plot(freq, fft[0, channel_idx, :])
plt.title(f'FFT of Channel {channel_idx+1} ({low_freq}-{high_freq} Hz)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude')
plt.tight_layout()
plt.show()
与前文所绘制的图一致。
另外,如果默认不设置low_freq
和high_freq
,即
fft, freq = FFTLayer(sampling_rate=sampling_rate)(sig)
则绘制全部的正负频率,如下图: