it编程 > 游戏开发 > ar

PyTorch核心方法之state_dict()、parameters()参数打印与应用案例

48人参与 2025-12-14 ar

前言

本文以 lenet-5 模型为案例,介绍了 pytorch 中打印模型参数的相关方法。首先展示了 lenet-5 模型的结构定义及打印结果;随后详细说明了三种获取模型参数的方式:

模型案例

本文以lenet-5为基础模型,快速验证模型参数打印过程。

import os 
os.environ['cuda_visible_devices'] = '3'
import torch 
import torch.nn.functional as f 
import torch.nn as nn

class lenet5(nn.module):
    def __init__(self):
        super(lenet5, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.conv2d(1, 6, 5)
        self.conv2 = nn.conv2d(6, 16, 5)
        # an affine operation: y = wx + b
        self.fc1 = nn.linear(16 * 5 * 5, 120) # 这里论文上写的是conv,官方教程用了线性层
        self.fc2 = nn.linear(120, 84)
        self.fc3 = nn.linear(84, 10)

    def forward(self, x):
        # max pooling over a (2, 2) window
        x = f.max_pool2d(f.relu(self.conv1(x)), (2, 2))
        # if the size is a square you can only specify a single number
        x = f.max_pool2d(f.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = lenet5()
print(net)

模型结构打印如下。

a. state_dict()方法验证

在 pytorch 中,state_dict() 是核心方法之一,用于以有序字典(ordereddict)的形式返回模型 / 优化器等实例的可学习参数(或状态),是模型保存、加载、迁移学习的基础。

state_dict() 本质是一个 python 字典(pytorch 中为 ordereddict),键为参数 / 状态的名称(字符串),值为对应的张量(torch.tensor)。

print(type(net.state_dict()))   # <class 'collections.ordereddict'>
## 遍历打印
for model_key in net.state_dict():      # 【字典格式】的遍历,获取的是模型的名称
    print(f"{model_key}: {net.state_dict()[model_key].size()}")

对于lenet-5模型进行打印,可以看到state_dict()的类型为 <class 'collections.ordereddict'>,各层名称及参数尺寸如下图所示。

b. parameters()

parameters()方法也可以获取到模型的参数。可以看出,parameters()获取到的是一个生成器,其中仅包含各层参数的信息。

params = net.parameters()   
print(type(params))   # <class 'generator'>  生成器  

for param in params:    
    print(param.size())   # 只包含参数信息:具体的参数尺寸

对lenet-5进行模型参数打印。

如果也需要模型名称信息,可以使用named_parameters()方法。该方法获取的也是一个生成器,其中返回的是一个元组,包括模型名称和对应的参数。

named_params = net.named_parameters()   
print(type(named_params))   # <class 'generator'>  也是一个生成器

for name, param in named_params:
    print(f"{name}: {param.size()}")   # 同时获取网络名称和网络参数

对lenet-5进行模型名称及参数尺寸信息打印:

c. 模型结构冻结示例

该方法可以在对模型结构冻结时使用,如下述示例对模型结构m的参数进行冻结,同时打印确认冻结包含哪些网络结构。

# 示例
for name, param in m.named_parameters():
	param.requires_grad = false
	print(f"freezing layer {name}")

总结 

到此这篇关于pytorch核心方法之state_dict()、parameters()参数打印与应用案例的文章就介绍到这了,更多相关pytorch state_dict()、parameters()参数打印内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!

(0)

您想发表意见!!点此发布评论

推荐阅读

@NoArgsConstructor、@Getter、@Setter注解及Lombok的使用

12-11

@NoArgsConstructor注解

12-11

新增淘宝闪购等app! 华为鸿蒙 HarmonyOS 5/6实况窗支持应用更新

12-10

Prometheus+cpolar如何在手机上也能监控服务器状态?

12-08

接收文件可自动扫描病毒! 鸿蒙电脑 HarmonyOS 6.0.0.115 SP7版本推送更新

12-04

tomcat点击startup.bat一闪而过的解决全过程

12-03

猜你喜欢

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论