浅谈torch.nn库和torch.nn.functional库
这两个库很类似,都涵盖了神经网络的各层操作,只是用法有点不同,
nn下是类实现,nn.functional下是函数实现。
conv1d
- 在nn下是一个类,一般继承nn.module通过定义forward()函数计算其值
class Conv1d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
super(Conv1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias)
def forward(self, input):
return torch.nn.functional.conv1d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
- 在nn.functional下直接传入参数即可使用,其会直接返回一个torch.nn.functional的函数,和上面类中的forward()中的函数一致
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1):
if input is not None and input.dim() != 3:
raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
_single(0), groups, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
- nn.Xxx不需要自己定义和管理参数weight;而nn.functional.xxx需要自己定义weight,每次调用的时候都需要手动传入weight
同样droptout用nn定义的话在训练时生效,在eval()时无效。
损失函数Loss(交叉熵)
- nn库
import torch
import torch.nn as nn
Loss = nn.BCELoss()
a = torch.ones(2,2)
b = torch.ones(2,2)
c = Loss(a,b)
- nn.functional库
import torch
import torch.nn.functional as nn
a = torch.ones(2,2)
b = torch.ones(2,2)
c = nn.binary_cross_entropy(a,b)
c的结果都一样为0,即两个分布高度相似
总结一下,两个库都可以实现神经网络的各层运算。其他包括卷积、池化、padding、激活(非线性层)、线性层、正则化层、其他损失函数Loss,两者都可以实现
nn.functional.xxx是函数接口,而nn.Xxx是nn.functional.xxx的类封装,并且nn.Xxx都继承于一个共同祖先nn.Module。因此nn.Xxx除了具有nn.functional.xxx功能(通过类中的forward方法实现),内部附带了nn.Module相关的属性和方法,例如train(), eval(),load_state_dict, state_dict 等,可以自动管理各层的参数,同时还可以实现如Sequential()将多个运算层组合为一个逻辑层。
参考
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.functional.html#
https://www.zhihu.com/question/66782101