Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

发布时间:2026/5/22 5:00:26

Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型 相关阅读Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482笔者在使用torch.nn.module的load_state_dict中出现了一个问题一个被注册的张量在加载后居然没有变化一开始以为是加载出现了问题但发现其他参数加载成功思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型恰好注册时的默认值给的是0而checkpoint中的浮点数又在0到1之间因此出现了这个令人困惑的bug。下面首先复现这个bug。import torch import torch.nn as nn # 定义一个简单的线性模型参数类型为整数 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.register_buffer(test, torch.tensor(0)) # 注册一个整型张量 # 创建一个简单模型实例 model SimpleModel() # 创建一个浮点数作为参数 float_parameter torch.tensor(0.6) # 将注册名指向另一个浮点型张量 model.test float_parameter # 保存模型 torch.save(model.state_dict(), model.pth) # 直接使用原模型加载 checkpoint torch.load(model.pth) model.load_state_dict(checkpoint) # 打印加载后的参数 print(model.test) # 直接使用新模型加载 model_1 SimpleModel() model_1.load_state_dict(checkpoint) # 打印加载后的参数 print(model_1.test)输出 tensor(0.6000) tensor(0)可以看到当模型中注册的名字(test)指向了一个类型不符的张量后并不会导致浮点型张量被截断为整型这是因为此处是直接使用赋值号使名字指向了另一个张量。但使用load_state_dict()方法与使用赋值号是不同的load_state_dict()方法的实现中调用了_load_from_state_dict()方法其中调用了copy_()方法进行了原位(in-place)数据替换这可能会进行截断下面是原位替换的一个例子。import torch # 创建两个张量 a torch.tensor([[1, 2], [3, 4]]) b torch.tensor([[5.1, 6.1], [7.1, 8.1]]) # 查看张量对象的id print(id(a)) print(id(b)) # 查看底层存储的内存地址 print(a.storage().data_ptr()) print(b.storage().data_ptr()) # 将张量 b 中的值复制到张量 a 中 a.copy_(b) # 打印复制后的结果 print(a) # 查看张量对象的id print(id(a)) print(id(b)) # 查看底层存储的内存地址 print(a.storage().data_ptr()) print(b.storage().data_ptr())输出 2604425272672 2604426953808 2604511348096 2602930352832 tensor([[5, 6], [7, 8]]) 2604425272672 2604426953808 2604511348096 2602930352832在保存了模型的状态字典后使用load_state_dict()方法加载后也不会有任何截断问题因为对于原模型而言名字test指向的是一个浮点型张量此时原位替换类型吻合。但是对于一个新的模型此时的test指向的是一个整型张量此时原位替换会发生截断。因此在注册一个张量时需要确保其在注册时和保存时的类型吻合此处除了指形状还有类型否则可能会出现意想不到的bug。

相关新闻