1、Net(input)调用的是什么函数?为什么直接写对象名就直接调用函数了?
net是创建的vgg类的对象,vgg类继承于pytorch库中类nn.Module。创建类时的括号里写上父类的名字,就是继承的意思。
在pytorch库中nn.Module定义如下:
forward : Callable[…, Any] = _unimplemented_forward ,call : Callable[…, Any] = _call_impl
这个冒号意思基本类似于forword函数就是 _unimplemented_forward,只是输入输出都是任意的。
call 就是_call_impl输入输出都是任意的。
源码可以简化为下面的函数:
当子类调用c1(),__call__调用_call_impl(*args),再调用self.forwords,如果子类有forward,那就调用子类的forword,如果子类没有forword,就调用父类的forward,就是_unimplemented_forward,直接报错。
当父类和子类调用同一个函数名 F,子类的对象调用的是子类的F函数。
2、net.train()函数从何而来?
都是父类的nn.Module的函数,由net继承。 net.eval()、net.load_state_dict()都是父类的函数
3、net.eval()有啥用?
如果不写net.eval(),每次模型测试的概率值都不一样。
4、@torch.no_grad() 有啥用?为啥要卸载测试函数前面?
@这样写,就是装饰器,测试函数不算梯度,是为了节省空间。python的装饰器,写在一个函数前面,就是说在执行这个函数的时候,就会执行装饰器里的内容。
其他的重点内容,看下面两篇文章。
《浅谈 PyTorch 中的 tensor 及使用》https://zhuanlan.zhihu.com/p/69294347
《PyTorch 的 Autograd》https://zhuanlan.zhihu.com/p/67184419