def rename_layers(self, module, prefix=''):
for name, child in list(module.named_children()):
# 如果需要添加前缀,则修改名称
if prefix:
new_name = prefix + name
# 在这里重命名子模块
setattr(module, new_name, child)
# 删除旧的名称对应的子模块
delattr(module, name)
def rename_layers_back(self, module, prefix=''):
for name, child in list(module.named_children()):
# 如果需要添加前缀,则修改名称
if prefix:
new_name = name.replace(prefix, "")
# 在这里重命名子模块
setattr(module, new_name, child)
# 删除旧的名称对应的子模块
delattr(module, name)
# 更新当前网络的结构字典
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=False)
strict = False # 部分参数名称不一致时可忽略,不会报错