问题描述:
在运行如下代码时会报错
U, V = gd_factorise_ad(A,2)
然后就报错了。
RuntimeError: Found dtype Double but expected Float
我试了很多种加float都不行,最后通过资料发现,是计算mse这步出错。
原因分析:
直面意思就是 类型应该是float而不是double
解决方案:
由于我们在利用torch计算mseloss时的张量类型有误而导致的,因此,在计算mse的时候,往里面加入tensor.float()即可
原来的代码如下:
loss_fn = torch.nn