Abstract
Paper link: https://arxiv.org/pdf/2404.07846
Paper title: Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) is a self-attention mechanism designed to overcome the limitations of conventional self-attention in image processing, particularly for image denoising and restoration. By introducing a masking scheme, M-WSA enforces the blind-spot requirement while computing attention and thus prevents noise information from leaking into the output.
Design Principles
- Window self-attention: M-WSA builds on window self-attention (WSA), partitioning the input image into non-overlapping windows and computing self-attention within each window to capture local features. This keeps the computational cost comparatively low and makes the method suitable for high-resolution images.
- Masking mechanism: to satisfy the blind-spot requirement, M-WSA applies a mask during the attention computation. The mask restricts each pixel to attend only to specific pixels inside its window, preventing access to blind-spot information and ensuring that no noise information is leaked during denoising. A toy sketch of such a mask follows this list.
- Dilated-convolution analogy: the mask pattern mimics the receptive field of a dilated convolution, so the network can gather context over a larger area while staying computationally efficient. This effectively enlarges the receptive field and strengthens feature extraction.
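To make the masking idea concrete, here is a small 1-D toy sketch (purely illustrative, written to mirror the pattern produced by the FixedPosEmb module in the code below rather than the paper's exact formulation): keys at odd relative offsets are blocked with -inf, so each query only sees the positions a dilation-2 convolution would sample.

import torch

# Toy 1-D dilated attention mask (illustrative only): keys at odd relative offsets
# get -inf, so the allowed positions per query form the grid a dilation-2
# convolution would sample within the window.
L = 9
rel = torch.arange(L)[None, :] - torch.arange(L)[:, None]  # relative offsets j - i
mask = torch.zeros(L, L)
mask[rel % 2 != 0] = float('-inf')
print(mask)  # row i: 0 where query i may attend, -inf where it may not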
Advantages
- Efficiency: restricting the attention computation to local windows reduces the computational complexity substantially, making M-WSA practical for large-scale image processing; a rough cost comparison is sketched after this list.
- Information protection: the masking mechanism keeps blind-spot information from leaking, which improves denoising quality, especially for images with spatially correlated noise.
- Flexibility: M-WSA can be combined with other network architectures to strengthen performance on a range of vision tasks, particularly in self-supervised learning and image restoration.
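As a back-of-the-envelope illustration of the efficiency claim (the sizes below are assumed toy values, not measurements from the paper): the attention-score cost of global self-attention grows with (H*W)^2, while window attention grows only with H*W * M^2 for an M x M window.

# Rough cost comparison between global and window attention (toy numbers, not a benchmark).
H, W, C, M = 256, 256, 64, 8
n = H * W
global_cost = n * n * C        # every pixel attends to every pixel
window_cost = n * M * M * C    # every pixel attends only within its M x M window
print(f"global attention : {global_cost:.2e} multiply-adds")
print(f"window attention : {window_cost:.2e} multiply-adds "
      f"(~{global_cost / window_cost:.0f}x cheaper)")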
Experimental Results
Experiments on several real-world image denoising datasets show that M-WSA clearly improves denoising performance, surpassing conventional convolutional networks and other self-attention mechanisms. This indicates that M-WSA adapts well to complex noise patterns.
Code
By combining window self-attention with a masking mechanism, Masked Window-Based Self-Attention (M-WSA) offers an effective solution for image denoising and restoration: it improves computational efficiency while keeping the blind-spot constraint intact, and it shows broad potential for self-supervised learning. Code:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum


def to(x):
    # Helper: reuse the device/dtype of an existing tensor when creating new ones.
    return {'device': x.device, 'dtype': x.dtype}


def expand_dim(t, dim, k):
    # Insert a new axis at position `dim` and expand it to size k (no copy is made).
    t = t.unsqueeze(dim=dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)


def rel_to_abs(x):
    # Convert relative-position logits to absolute-position logits
    # (the standard padding-and-reshape trick from relative attention).
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim=2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x


def relative_logits_1d(q, rel_k):
    # Relative-position logits along one spatial axis of the query windows.
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim=2, k=r)
    return logits


class RelPosEmb(nn.Module):
    def __init__(
            self,
            block_size,
            rel_size,
            dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5
        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (x y) c -> b x y c', x=block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h


class FixedPosEmb(nn.Module):
    def __init__(self, window_size, overlap_window_size):
        super().__init__()
        self.window_size = window_size
        self.overlap_window_size = overlap_window_size

        # Mask table over relative positions: alternating offsets are set to -inf,
        # so each query may only attend to keys lying on a dilated (stride-2) grid
        # inside the overlapping window.
        attention_mask_table = torch.zeros((window_size + overlap_window_size - 1),
                                           (window_size + overlap_window_size - 1))
        attention_mask_table[0::2, :] = float('-inf')
        attention_mask_table[:, 0::2] = float('-inf')
        attention_mask_table = attention_mask_table.view(
            (window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten_1 = torch.flatten(coords, 1)  # 2, Wh*Ww
        coords_h = torch.arange(self.overlap_window_size)
        coords_w = torch.arange(self.overlap_window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten_2 = torch.flatten(coords, 1)

        relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.overlap_window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.overlap_window_size - 1
        relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(
            1, self.window_size ** 2, self.overlap_window_size ** 2
        ), requires_grad=False)

    def forward(self):
        return self.attention_mask


class DilatedOCA(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):
        super(DilatedOCA, self).__init__()
        self.num_spatial_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        self.dim_head = dim_head
        self.inner_dim = self.dim_head * self.num_spatial_heads
        self.scale = self.dim_head ** -0.5

        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,
                                padding=(self.overlap_win_size - window_size) // 2)
        self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
        self.rel_pos_emb = RelPosEmb(
            block_size=window_size,
            rel_size=window_size + (self.overlap_win_size - window_size),
            dim_head=self.dim_head
        )
        self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv(x)
        qs, ks, vs = qkv.chunk(3, dim=1)

        # spatial attention: queries come from non-overlapping windows,
        # keys/values from the enlarged overlapping windows
        qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        ks, vs = map(lambda t: self.unfold(t), (ks, vs))
        ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))
        # print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')

        # split heads
        qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),
                         (qs, ks, vs))

        # attention with relative position logits and the fixed attention mask
        qs = qs * self.scale
        spatial_attn = (qs @ ks.transpose(-2, -1))
        spatial_attn += self.rel_pos_emb(qs)
        spatial_attn += self.fixed_pos_emb()
        spatial_attn = spatial_attn.softmax(dim=-1)

        out = (spatial_attn @ vs)
        out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,
                        h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)

        # merge spatial and channel
        out = self.project_out(out)
        return out
if __name__ == "__main__":
    dim = 64
    window_size = 8
    overlap_ratio = 0.5
    num_heads = 2
    dim_head = 16

    # instantiate the DilatedOCA module
    oca_attention = DilatedOCA(
        dim=dim,
        window_size=window_size,
        overlap_ratio=overlap_ratio,
        num_heads=num_heads,
        dim_head=dim_head,
        bias=True
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    oca_attention = oca_attention.to(device)
    print(oca_attention)

    # input channels must match `dim`, and H/W must be multiples of window_size
    x = torch.randn(1, dim, 640, 480).to(device)

    # forward pass
    output = oca_attention(x)
    print("input shape:", x.shape)
    print("output shape:", output.shape)
DilatedOCA Module Walkthrough
Code Structure
import torch
import torch.nn as nn
from einops import rearrange
- Imports: PyTorch and the einops library are imported first; einops is used to simplify tensor rearrangement, as the one-line example below shows.
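A minimal example of what rearrange does (toy shapes chosen for illustration):

import torch
from einops import rearrange

# Flatten the spatial dimensions of a feature map into a token axis,
# the kind of reshuffling used throughout the module below.
x = torch.randn(2, 64, 8, 8)                  # (B, C, H, W)
tokens = rearrange(x, 'b c h w -> b (h w) c')
print(tokens.shape)                           # torch.Size([2, 64, 64])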
Module Definition
class DilatedOCA(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):
        super(DilatedOCA, self).__init__()
        self.num_spatial_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        self.dim_head = dim_head
        self.inner_dim = self.dim_head * self.num_spatial_heads
        self.scale = self.dim_head ** -0.5

        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,
                                padding=(self.overlap_win_size - window_size) // 2)
        self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
        self.rel_pos_emb = RelPosEmb(
            block_size=window_size,
            rel_size=window_size + (self.overlap_win_size - window_size),
            dim_head=self.dim_head
        )
        self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
- Initialization: the __init__ method defines the module's structure.
  - dim: number of channels of the input features.
  - window_size: size of the query window used for spatial attention.
  - overlap_ratio: overlap ratio that controls how far the key/value windows extend beyond the query window; the enlarged window size is overlap_win_size = int(window_size * overlap_ratio) + window_size (for example, 12 when window_size=8 and overlap_ratio=0.5).
  - num_heads: number of spatial attention heads.
  - dim_head: dimension of each head.
- Layer definitions:
  - self.unfold: unfolds the input tensor into overlapping key/value windows (see the small sanity check below).
  - self.qkv: a 1x1 convolution that produces the query (Q), key (K) and value (V) feature maps.
  - self.project_out: a 1x1 convolution that maps the output features back to the original number of channels.
  - self.rel_pos_emb and self.fixed_pos_emb: positional-encoding modules that strengthen the model's awareness of spatial position; fixed_pos_emb additionally supplies the fixed attention mask returned by its forward().
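A quick sanity check of the unfold configuration (toy sizes assumed: window_size=8, overlap_ratio=0.5, hence overlap_win_size=12): each 8x8 query window is paired with a 12x12 key/value window centred on it.

import torch
import torch.nn as nn

# Toy check of the unfold settings used in DilatedOCA.__init__.
window_size, overlap_win_size = 8, 12          # 12 = int(8 * 0.5) + 8
unfold = nn.Unfold(kernel_size=(overlap_win_size, overlap_win_size),
                   stride=window_size,
                   padding=(overlap_win_size - window_size) // 2)
feat = torch.randn(1, 32, 16, 16)              # (B, C, H, W) with H, W multiples of 8
patches = unfold(feat)
print(patches.shape)                           # torch.Size([1, 4608, 4]) = (B, C*12*12, num_windows)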
Forward Pass
def forward(self, x):
    b, c, h, w = x.shape
    qkv = self.qkv(x)
    qs, ks, vs = qkv.chunk(3, dim=1)

    # spatial attention
    qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)
    ks, vs = map(lambda t: self.unfold(t), (ks, vs))
    ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))

    # split heads
    qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),
                     (qs, ks, vs))

    # attention
    qs = qs * self.scale
    spatial_attn = (qs @ ks.transpose(-2, -1))
    spatial_attn += self.rel_pos_emb(qs)
    spatial_attn += self.fixed_pos_emb()
    spatial_attn = spatial_attn.softmax(dim=-1)

    out = (spatial_attn @ vs)
    out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,
                    h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)

    # merge spatial and channel
    out = self.project_out(out)
    return out
- Input shape: x has shape (batch_size, channels, height, width), where b is the batch size, c the number of channels, and h and w the image height and width.
- Feature extraction:
  - qkv = self.qkv(x): the qkv layer produces the stacked Q, K and V feature maps.
  - qs, ks, vs = qkv.chunk(3, dim=1): the Q, K and V maps are split along the channel dimension, as the small sketch below shows.
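A small sketch of the Q/K/V projection with assumed toy sizes: a single 1x1 convolution emits 3 * inner_dim channels, and chunk() splits them into the three attention inputs.

import torch
import torch.nn as nn

# Toy Q/K/V projection: 3 * inner_dim output channels, split along the channel axis.
dim, inner_dim = 64, 32
qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, bias=True)
x = torch.randn(1, dim, 16, 16)
q, k, v = qkv(x).chunk(3, dim=1)
print(q.shape, k.shape, v.shape)  # each torch.Size([1, 32, 16, 16])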
- Spatial attention preparation (a shape trace is sketched right after this item):
  - qs is rearranged into per-window token sequences suitable for window attention.
  - ks and vs are unfolded by self.unfold into overlapping key/value windows.
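A shape trace with assumed toy settings (dim=64, inner_dim=32, window_size=8, overlap_win_size=12, input 1x64x16x16), mirroring the first few lines of forward():

import torch
import torch.nn as nn
from einops import rearrange

# Shape trace of the window-preparation step (toy sizes, mirrors DilatedOCA.forward).
dim, inner_dim, window_size, overlap_win_size = 64, 32, 8, 12
qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1)
unfold = nn.Unfold(kernel_size=(overlap_win_size, overlap_win_size), stride=window_size,
                   padding=(overlap_win_size - window_size) // 2)

x = torch.randn(1, dim, 16, 16)
qs, ks, vs = qkv(x).chunk(3, dim=1)
qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=window_size, p2=window_size)
ks, vs = map(unfold, (ks, vs))
ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=inner_dim), (ks, vs))
print(qs.shape, ks.shape, vs.shape)
# torch.Size([4, 64, 32]) torch.Size([4, 144, 32]) torch.Size([4, 144, 32])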
- Head splitting: einops.rearrange reshapes Q, K and V into a per-head layout for multi-head self-attention, folding the head axis into the batch dimension (see the one-line example below).
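For example, with assumed toy shapes, folding two heads into the batch axis looks like this:

import torch
from einops import rearrange

# (windows, tokens, head * dim_head) becomes (windows * head, tokens, dim_head),
# so the subsequent matmuls run per head without an explicit loop.
t = torch.randn(4, 64, 32)
t = rearrange(t, 'b n (head c) -> (b head) n c', head=2)
print(t.shape)  # torch.Size([8, 64, 16])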
- Attention computation (a minimal standalone sketch follows this item):
  - qs = qs * self.scale: Q is scaled for numerical stability.
  - spatial_attn = (qs @ ks.transpose(-2, -1)): the attention scores are computed.
  - spatial_attn += self.rel_pos_emb(qs) and spatial_attn += self.fixed_pos_emb(): relative position logits and the fixed positional mask are added to strengthen spatial awareness.
  - spatial_attn = spatial_attn.softmax(dim=-1): the scores are normalized with softmax.
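Stripped of the module plumbing, this is ordinary scaled dot-product attention with additive bias terms. A minimal standalone sketch with toy shapes and zero-valued stand-ins for the two positional terms:

import torch

# Minimal scaled dot-product attention with additive bias/mask terms (toy sketch).
B, Nq, Nk, d = 8, 64, 144, 16
q = torch.randn(B, Nq, d) * d ** -0.5   # scaled queries
k = torch.randn(B, Nk, d)
v = torch.randn(B, Nk, d)
rel_bias = torch.zeros(B, Nq, Nk)       # stand-in for rel_pos_emb(q)
mask = torch.zeros(1, Nq, Nk)           # stand-in for fixed_pos_emb(); -inf entries would block keys

attn = q @ k.transpose(-2, -1) + rel_bias + mask
attn = attn.softmax(dim=-1)
out = attn @ v
print(out.shape)  # torch.Size([8, 64, 16])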
- Output computation: out = (spatial_attn @ vs) applies the attention weights to V as a weighted sum, giving the attended features.
- Output rearrangement: out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...) folds the windows and heads back into the original (B, C, H, W) layout.
- Final projection: out = self.project_out(out) maps the result back to the original number of channels through the 1x1 projection layer.
Summary
The DilatedOCA module combines a dilated-convolution-style attention mask with spatial attention over overlapping windows, strengthening its ability to capture local image features. It has broad application potential in image processing tasks, especially where fine-grained feature extraction is needed. A small sketch of how it might be dropped into a larger network follows.
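A minimal, hypothetical sketch (not from the paper) of how DilatedOCA could be used as a residual block inside a larger network, assuming the class defined above is in scope:

import torch
import torch.nn as nn

# Hypothetical residual wrapper around DilatedOCA (assumes the class defined above is importable).
class OCABlock(nn.Module):
    def __init__(self, dim=64, window_size=8, overlap_ratio=0.5, num_heads=2, dim_head=16):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)  # LayerNorm-like normalization over channels
        self.attn = DilatedOCA(dim, window_size, overlap_ratio, num_heads, dim_head, bias=True)

    def forward(self, x):
        return x + self.attn(self.norm(x))  # residual connection around the attention

x = torch.randn(1, 64, 64, 64)
print(OCABlock()(x).shape)  # torch.Size([1, 64, 64, 64])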