Bootstrap

报错记录:Pytorch: RuntimeError: Found dtype Double but expected Float

问题描述:

在运行如下代码时会报错

U, V = gd_factorise_ad(A,2)

然后就报错了。
RuntimeError: Found dtype Double but expected Float

我试了很多种加float都不行,最后通过资料发现,是计算mse这步出错。


原因分析:

直面意思就是 类型应该是float而不是double


解决方案:

由于我们在利用torch计算mseloss时的张量类型有误而导致的,因此,在计算mse的时候,往里面加入tensor.float()即可

原来的代码如下:

    loss_fn = torch.nn
;