Asymmetric Mask Scheme for Self-Supervised Real Image Denoising
1. Characteristics of BSN
A feature map is first obtained by feature extraction.
A blind-spot conv then extracts features from every pixel except the center one.
Dilated convs must follow, otherwise the receptive field would grow back over the center pixel and the network could learn the identity mapping.
See the network in the AP-BSN paper for reference. A minimal sketch of such a restricted conv follows.
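This is my own illustration, not the paper's code: a hypothetical CentralMaskedConv2d whose center weight is zeroed out, similar in spirit to the restricted convs used in AP-BSN.

import torch
import torch.nn as nn

class CentralMaskedConv2d(nn.Conv2d):
    """Conv whose center tap is zeroed, so the output never sees
    the pixel it is predicting (the 'blind spot')."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mask = torch.ones_like(self.weight)
        kh, kw = self.kernel_size
        mask[:, :, kh // 2, kw // 2] = 0  # zero the center weight
        self.register_buffer('mask', mask)

    def forward(self, x):
        # masking the weight also zeroes the center weight's gradient
        return nn.functional.conv2d(x, self.weight * self.mask, self.bias,
                                    self.stride, self.padding, self.dilation, self.groups)

# e.g. a dilated blind-spot layer (hypothetical usage):
# conv = CentralMaskedConv2d(3, 16, kernel_size=3, padding=2, dilation=2)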
2. Inspiration from MAE
Because BSN relies on such restricted filters, the network design is heavily constrained.
In MAE, an image can still be reconstructed even after it is masked, so the authors designed a mask-based net.
Masking the image directly and then restoring it avoids BSN's constraints, so the BSN can be replaced by an ordinary network.
3. Handling spatial noise correlation
It also uses pixel-shuffle downsampling (PD), just with a mask added on top; a small PD demo follows.
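A quick illustration of pixel-shuffle downsampling on a toy tensor: neighboring noisy pixels end up in different sub-images, which weakens their spatial noise correlation.

import torch
import torch.nn.functional as F

x = torch.arange(16.).view(1, 1, 4, 4)   # toy 4x4 image
sub = F.pixel_unshuffle(x, 2)            # (1, 4, 2, 2): four sub-images
# sub[0, 0] holds pixels (0,0), (0,2), (2,0), (2,2) of x -- adjacent
# pixels of x are split across the 4 channels, decorrelating the noise.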
4. Asymmetric mask scheme
Training uses only one branch network.
Inference uses the same single branch, but feeds it two complementarily masked images.
At inference the masks satisfy: across the branches, the masked pixels of all branches together make up the entire image (a minimal sketch follows).
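A minimal sketch of the idea with placeholder names (net, the 50% mask): training builds the loss on the masked pixels of one random mask, while inference runs the same branch on two complementary masked copies and stitches the masked pixels together.

import torch

def train_step(net, x):
    m = torch.rand_like(x[:, :1]) < 0.5      # random mask; True = visible
    m = m.expand_as(x)
    pred = net(x * m)                        # masked pixels zeroed
    return ((pred - x)[~m] ** 2).mean()      # loss only on masked pixels

def infer(net, x):
    m = torch.rand_like(x[:, :1]) < 0.5
    m = m.expand_as(x)
    out = torch.zeros_like(x)
    out[~m] = net(x * m)[~m]                 # branch 1 fills its masked pixels
    out[m] = net(x * ~m)[m]                  # complementary mask fills the rest
    return out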
5. Checkerboard artifacts from pixel shuffle
A new loss function is introduced for fine-tuning (see the TV loss in 6.6).
6. Code
6.1 Image decomposition and recomposition
Uses the pixel_shuffle and pixel_unshuffle functions; corresponds to the top part of Figure 5.
import torch
import torch.nn.functional as F

def pd_down(x: torch.Tensor, pd_factor: int = 5, pad: int = 0) -> torch.Tensor:
    # Split x into pd_factor^2 sub-images, stacked along the batch dimension.
    b, c, h, w = x.shape
    x_down = F.pixel_unshuffle(x, pd_factor)
    out = x_down.view(b, c, pd_factor, pd_factor, h // pd_factor, w // pd_factor).permute(
        0, 2, 3, 1, 4, 5).reshape(b * pd_factor * pd_factor, c, h // pd_factor, w // pd_factor)
    return out

def pd_up(out: torch.Tensor, pd_factor: int = 5, pad: int = 0) -> torch.Tensor:
    # Inverse of pd_down: merge the sub-images back to full resolution.
    b, c, h, w = out.shape
    # Reshape the output tensor to its layout before the batch split
    x_down = out.view(b // (pd_factor ** 2), pd_factor, pd_factor, c, h,
                      w).permute(0, 3, 1, 2, 4, 5)
    x_down = x_down.reshape(b // (pd_factor ** 2), c * pd_factor * pd_factor, h, w)
    # Use pixel shuffle to upsample the tensor
    x_up = F.pixel_shuffle(x_down, pd_factor)
    return x_up
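A quick round-trip sanity check for the two functions (H and W must be divisible by pd_factor):

x = torch.rand(2, 3, 20, 20)
y = pd_down(x, 5)                       # -> (2*25, 3, 4, 4): 25 sub-images per input
assert torch.equal(pd_up(y, 5), x)      # exact inverse, pure rearrangement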
6.2 Training flow and testing flow
import torch.nn as nn

class MultiMaskPdDn(nn.Module):
    def __init__(self, pd_train: int = 5, pd_val: int = 2, dn_net: str = 'default', r3: float = -1, r3_num: int = 8,
                 net_param: dict[str, float | str | int] = None, **kwargs):
        super().__init__()
        self.dn = dn_dict[dn_net](**net_param if net_param is not None else {})
        self.pd_train = pd_train
        self.pd_val = pd_val
        self.r3 = R3(r3, r3_num)

    def denoise(self, x: torch.Tensor, pd_factor: int = None, return_mask: bool = False, only_first: bool = True) -> torch.Tensor:
        # downsample -> denoise -> upsample
        if pd_factor is None:
            pd_factor = self.pd_train
        if pd_factor > 1:
            x = util.pd_down(x, pd_factor)
        if return_mask:
            dn_img, masks = self.dn(x, True, only_first)
        else:
            dn_img = self.dn(x, False, only_first)
        if pd_factor > 1:
            dn_img = util.pd_up(dn_img, pd_factor)
        return dn_img if not return_mask else (dn_img, masks)

    def forward(self, x: torch.Tensor, pd_factor: int = None, return_mask: bool = False, only_first: bool = True) -> torch.Tensor:
        if self.training:
            return self.denoise(x, pd_factor, return_mask, only_first)
        else:
            # compared with training, testing adds the R3 refinement
            denoised = self.denoise(x, self.pd_val, only_first=False)
            return self.r3(x, denoised, self.dn)
6.3 Mask
Corresponds to Eq. (5) in the paper.
If there are 3 masks, mask_index contains random values 0, 1, 2.
masks == False (0) marks the masked region.
res is the image with the masked region set to 0.
So with n masks, the masked region of each mask covers about 1/n of the image (in expectation).
class MultiScaleMask(nn.Module):
    def __init__(self, scale_num: int = 2):
        super().__init__()
        self.scale_num = scale_num

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        b, c, h, w = x.shape
        # fill the whole image with random integers in [0, scale_num - 1]
        mask_index = torch.randint(
            0, self.scale_num, (b, 1, h, w)).expand(-1, c, -1, -1).to(x.device)
        res = torch.zeros(self.scale_num, *x.shape).to(x.device)
        masks = torch.empty(self.scale_num, *x.shape, dtype=torch.bool, device=x.device)
        for i in range(self.scale_num):
            temp = mask_index != i      # False where this mask hides the pixel
            masks[i] = temp
            res[i] = x * temp           # masked pixels set to 0
        return res, masks
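A small demo of the class above: with scale_num=2 the two masks are exactly complementary, so their masked (False) regions tile the whole image.

mask_gen = MultiScaleMask(scale_num=2)
x = torch.rand(1, 3, 8, 8)
res, masks = mask_gen(x)
# mask_index is either 0 or 1, so one mask is the negation of the other:
assert torch.equal(masks[0], ~masks[1])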
Two masks by default.
Each mask has shape BCHW.
The masked images are fed into the branch; in fact there is only a single branch.
The two masked images are passed through the branch net to get out,
and only the masked elements of out are kept.
def forward(self, x: torch.Tensor, return_mask: bool = False, only_first: bool = False) -> torch.Tensor:
    masked_img, masks = self.mask(x)
    dn_img = torch.zeros_like(x).to(dtype=torch.float32)
    order_len = min(self.mask_num, len(self.branches_order))
    for i, j in zip(self.branches_order, range(order_len)):
        out = self.branches[i](masked_img[j])
        dn_img[~masks[j]] = out[~masks[j]]   # keep only the masked pixels of each output
        # training uses only one mask; the mask is random anyway
        if only_first and return_mask:
            break
    return dn_img if not return_mask else (dn_img, masks)
6.4 Loss functions
During training only_first=True, so only one mask takes effect: in effect, a random 50% of the pixels are masked and the loss is built on them.
At inference, however, multiple masks are used, and the masked pixels of all of them together cover the entire image.
By default the 2 masks are complementary, so the masked regions of the two denoised outputs combine into the complete denoised image (see 6.3).
In short, the unmasked pixels are used to predict the masked pixels.
class MaskLoss(nn.Module):
    def __init__(self, loss_type: str = 'l1') -> None:
        super().__init__()
        self.loss = losses_dict[loss_type]()

    def forward(self, input: torch.Tensor, output: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
        total_loss = 0
        for mask in masks:
            # loss only on the masked (False) pixels
            total_loss += self.loss(input[~mask], output[~mask])
        return total_loss

# Build the loss on the first branch only
class FirstBranchMaskLoss(nn.Module):
    def __init__(self, loss_type: str = 'l1') -> None:
        super().__init__()
        self.loss = losses_dict[loss_type]()

    def forward(self, input: torch.Tensor, output: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
        total_loss = 0
        for mask in masks:
            total_loss += self.loss(input[~mask], output[~mask])
            break   # only the first mask contributes
        return total_loss
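A usage sketch; losses_dict is assumed here to be a simple registry mapping 'l1' to nn.L1Loss (the repo defines its own).

losses_dict = {'l1': nn.L1Loss}       # assumed registry
criterion = MaskLoss('l1')
x = torch.rand(1, 3, 8, 8)            # noisy input
out = torch.rand(1, 3, 8, 8)          # network output
masks = [torch.rand(1, 3, 8, 8) > 0.5]
loss = criterion(x, out, masks)       # L1 only on the masked (False) pixels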
6.5 R3 enhancement
This code requires batch_size == 1.
class R3(nn.Module):
    def __init__(self, r3: float = -1, r3_num: int = 8):
        super().__init__()
        self.r3 = r3
        self.r3_num = r3_num
        if r3 <= 0:
            self.enhance = AsymMaskEnhance()

    def forward(self, x: torch.Tensor, denoised: torch.Tensor, net: nn.Module) -> torch.Tensor:
        if self.r3 > 0:
            return util.r3(x, denoised, net, self.r3, self.r3_num)
        else:
            return denoised
            # return self.enhance(x, denoised, net)
def r3(x: torch.Tensor, denoised: torch.Tensor, net: nn.Module, r3_factor: float = 0.16, r3_num: int = 8,
       p: int = 0) -> torch.Tensor:
    """Random Replacement Refinement with ratio r3_factor.

    Note:
        This module is only used in eval, not in train. Eval takes r3_num times longer than train.

    Args:
        x (torch.Tensor): input tensor, BCHW with B = 1.
        net (nn.Module): model to eval.
        r3_factor (float, optional): the ratio of random replacement. Defaults to 0.16.
        r3_num (int, optional): the number of R3 repetitions. Defaults to 8.

    Output: BCHW
    """
    b, c, h, w = x.shape
    temp_input = denoised.expand(r3_num, -1, -1, -1)
    x = x.expand(r3_num, -1, -1, -1).to(dtype=torch.float32)
    indices = torch.zeros(r3_num, c, h, w, dtype=torch.bool, device=x.device)
    for t in range(r3_num):
        indices[t] = (torch.rand(1, h, w) < r3_factor).repeat(c, 1, 1)
    # r3_factor (e.g. 16%) of the denoised pixels are replaced with the noisy input x
    temp_input = temp_input.clone()
    temp_input[indices] = x[indices]
    temp_input = F.pad(temp_input, (p, p, p, p), mode='reflect')
    # feed into net, then average over the r3_num outputs
    with torch.no_grad():
        if p == 0:
            denoised = net(temp_input)
        else:
            denoised = net(temp_input)[:, :, p:-p, p:-p]
    return torch.mean(denoised, dim=0).unsqueeze(0)
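A quick shape check for r3, using nn.Identity as a stand-in denoiser (not the real network):

x = torch.rand(1, 3, 16, 16)          # noisy input, B must be 1
denoised = torch.rand(1, 3, 16, 16)   # first-pass output
out = r3(x, denoised, nn.Identity(), r3_factor=0.16, r3_num=8)
print(out.shape)                      # torch.Size([1, 3, 16, 16])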
6.6 Smoothness enhancement
class TVLoss(torch.nn.L1Loss):
    """Weighted TV loss.

    Args:
        reduction (str): Loss method. Default: mean.
    """

    def __init__(self, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(TVLoss, self).__init__(reduction=reduction)

    def forward(self, pred):
        # L1 difference between vertically / horizontally adjacent pixels
        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :])
        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:])
        loss = x_diff + y_diff
        return loss
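Usage is straightforward; the loss is the mean absolute difference between neighboring pixels in both directions, which penalizes checkerboard patterns:

tv = TVLoss()
img = torch.rand(1, 3, 32, 32)
print(tv(img))   # scalar: x-direction diff + y-direction diff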
7. Experiments
The original paper trains with Restormer; I replaced it with a UNet32.
I built my own dataset.
With pd_a=5, pd_b=2 the result loses detail.
With pd_a=2, pd_b=2 the result also loses detail.
Fix: build the loss on the full image rather than only the masked pixels, and at inference add the 2 predictions and divide by 2.
This amounts to masking pixels of the original image, feeding it to the model, and building the loss between the output and the full original image.
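A minimal sketch of this modification as I read it (function names are placeholders; mask_gen is the MultiScaleMask from 6.3):

def train_step_fullimg(net, x, mask_gen):
    res, masks = mask_gen(x)              # res[0]: image with masked pixels zeroed
    out = net(res[0])
    return F.l1_loss(out, x)              # full-image loss, not just out[~masks[0]]

def infer_avg(net, x, mask_gen):
    res, masks = mask_gen(x)              # two complementary masked copies
    return (net(res[0]) + net(res[1])) / 2   # average the two predictions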