
torch.load raises ModuleNotFoundError or AttributeError

Python 3.11.3 (main, Apr  7 2023, 19:25:52) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

Normally, we use torch.save to save a model's state_dict. However, torch.save can also save an object of a custom type directly, for example (in a script named save.py):

import torch
import torch.nn as nn
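class Module(nn.Module):
    def __init__(self) -> None:
        # super().__init__() is not called here, so the instance's only
        # state is {'_one': 1}, which keeps the saved data minimal
        self._one = 1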
torch.save(Module(), 'module.pth')

Loading module.pth from another script may then fail with an AttributeError:

import torch
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 3, in <module>
    torch.load('module.pth')
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'Module' on <module '__main__' from '/Users/bytedance/Developer/todd/load.py'>

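This happens because torch.save is built on pickle, and pickle does not store the class definition when it saves an object of a custom type. It stores only a reference to the class (its module and name). The user must make sure the custom type is accessible when torch.load runs, so that the saved object can be reconstructed. Here the class was saved as __main__.Module, so once Module is brought into the namespace of the loading script, module.pth loads normally: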

import torch
from save import Module
torch.load('module.pth')
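To see what pickle actually records about a class, the opcode stream can be inspected with the standard-library pickletools module. The sketch below uses plain pickle rather than torch.save, and a standard-library class (fractions.Fraction) stands in for a user-defined one; both choices are assumptions made only to keep the example self-contained. Protocol 2 is used so that the class reference appears as a single GLOBAL opcode.

```python
import pickle
import pickletools
from fractions import Fraction

# Serialize an instance; no source code of the class goes into the payload.
payload = pickle.dumps(Fraction(1, 2), protocol=2)

# Collect (opcode name, argument) pairs from the pickle stream.
ops = [(op.name, arg) for op, arg, _pos in pickletools.genops(payload)]

# The class appears only as a "module name" pair in a GLOBAL opcode.
class_refs = [arg for name, arg in ops if name == "GLOBAL"]
print(class_refs)  # ['fractions Fraction']
```

At load time the unpickler imports the named module and looks up the named attribute, which is where ModuleNotFoundError and AttributeError come from.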

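In some cases, however, the custom type is not available, and we do not need to restore the saved object at all. We only want to know what data it holds. The following approach works: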

import torch
class Module:
    # Class-level default for the case where __setstate__ is not called.
    # Unpickling does not call __init__, so the default cannot be set there.
    _state = None
    def __setstate__(self, state):
        # whenever state is not empty, __setstate__ is called
        self._state = state
module = torch.load('module.pth')
print(module._state)
{'_one': 1}
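The behaviour this stub relies on can be checked with plain pickle, without torch. The sketch below is only an illustration: the module name demo_owner and the register helper are hypothetical and exist solely so that the "real" class and the stub can be swapped under the same import path within one script. It shows that __setstate__ receives the saved attributes, that it is skipped when the saved object had no attributes, and that __init__ is never called during unpickling.

```python
import pickle
import sys
import types

# Hypothetical module name, used only for this demonstration.
MODULE_NAME = "demo_owner"
owner = types.ModuleType(MODULE_NAME)
sys.modules[MODULE_NAME] = owner


def register(cls):
    """Make cls importable as demo_owner.<class name>."""
    cls.__module__ = MODULE_NAME
    setattr(owner, cls.__name__, cls)
    return cls


@register
class Module:  # the "real" class, present at save time
    def __init__(self, **attrs):
        self.__dict__.update(attrs)


with_state = pickle.dumps(Module(_one=1), protocol=2)
without_state = pickle.dumps(Module(), protocol=2)


@register
class Module:  # the stub that replaces the real class at load time
    _state = None  # used when __setstate__ is skipped
    init_called = False

    def __init__(self):
        type(self).init_called = True

    def __setstate__(self, state):
        self._state = state


print(pickle.loads(with_state)._state)     # {'_one': 1}
print(pickle.loads(without_state)._state)  # None
print(Module.init_called)                  # False
```

Because __init__ is bypassed, a default for _state has to live on the class itself rather than be assigned in __init__.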

However, if the custom type was imported from another module, for example

# module.py
import torch.nn as nn
class Module(nn.Module):
    def __init__(self) -> None:
        self._one = 1

# save.py
import torch
from module import Module
torch.save(Module(), 'module.pth')

torch.load first tries to import the corresponding module, and raises an error if that module does not exist:

Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 13, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'module'
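The same failure can be reproduced with plain pickle, which may make it easier to experiment without a saved .pth file. In this sketch the module name demo_missing_module is hypothetical; it is registered in sys.modules at save time and removed before loading to imitate a machine where the module is not installed.

```python
import pickle
import sys
import types

# Hypothetical module that owns the class at save time.
name = "demo_missing_module"
module = types.ModuleType(name)


class Module:
    def __init__(self):
        self._one = 1


Module.__module__ = name
module.Module = Module
sys.modules[name] = module

payload = pickle.dumps(Module(), protocol=2)

# Simulate loading in an environment where the module is unavailable.
del sys.modules[name]

error = None
try:
    pickle.loads(payload)
except ModuleNotFoundError as exc:
    error = exc

print(error)  # No module named 'demo_missing_module'
```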

We can try to mock the missing module:

import sys
from unittest.mock import Mock
import torch
sys.modules['module'] = Mock()
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 14, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: NEWOBJ class argument must be a type, not Mock
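To see why the unpickler complains, it can help to look at what attribute access on a Mock returns. The short check below uses only unittest.mock; the attribute name Module mirrors the class name from the example above.

```python
from unittest.mock import Mock

fake_module = Mock()

# Accessing an attribute that was never set creates a child Mock on the fly.
attr = fake_module.Module

print(isinstance(attr, Mock))      # True
print(isinstance(attr, type))      # False
print(fake_module.Module is attr)  # True
```

The unpickler needs a class in order to create the object, and the auto-created child Mock is an instance, not a class.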

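The error occurs because Mock creates attributes recursively on access: sys.modules['module'].Module is itself a Mock instance rather than a class. We can assign the attribute manually: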

import sys
from unittest.mock import Mock
import torch
class Module:
    # Class-level default: unpickling does not call __init__,
    # so the default cannot be set there.
    _state = None
    def __setstate__(self, state):
        self._state = state
sys.modules['module'] = Mock()
sys.modules['module'].Module = Module
module = torch.load('module.pth')
print(module._state)
{'_one': 1}
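When many modules or classes are missing, registering each stub by hand becomes tedious. One possible generalization, sketched here with plain pickle only, is an Unpickler subclass whose find_class falls back to a generated stub class. The names Stub, StubUnpickler, and gone_module are hypothetical, and whether this can be plugged into torch.load (for example through its pickle_module argument) is an assumption that would need checking against the installed PyTorch version.

```python
import io
import pickle
import sys
import types


class Stub:
    """Placeholder that records the pickled state instead of restoring it."""
    _state = None

    def __setstate__(self, state):
        self._state = state


class StubUnpickler(pickle.Unpickler):
    """Falls back to a generated Stub subclass when a class cannot be found."""

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except (ImportError, AttributeError):
            return type(name, (Stub,), {"__module__": module})


# Build a pickle whose class lives in a module that is then removed.
gone = types.ModuleType("gone_module")


class Module:
    def __init__(self):
        self._one = 1


Module.__module__ = "gone_module"
gone.Module = Module
sys.modules["gone_module"] = gone
payload = pickle.dumps(Module(), protocol=2)
del sys.modules["gone_module"]

restored = StubUnpickler(io.BytesIO(payload)).load()
print(type(restored).__name__, restored._state)  # Module {'_one': 1}
```

Classes that can still be imported are resolved normally, so only the missing ones are replaced by stubs.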