9人参与 • 2025-07-24 • Python
直接print(dir(nn.module)),得到如下内容:
parameters()
for param in model.parameters(): print(param.shape)
named_parameters()
for name, param in model.named_parameters(): if 'weight' in name: print(name, param.shape)
children()
for child in model.children(): print(type(child))
modules()
for module in model.modules(): if isinstance(module, nn.conv2d): print(module.kernel_size)
train()
和 eval()
model.train() # 训练模式 model.eval() # 推理模式
training
true
为训练,false
为推理)。print(model.training) # 输出:true/false
state_dict()
ordereddict
)。torch.save(model.state_dict(), 'model.pth')
load_state_dict()
model.load_state_dict(torch.load('model.pth'))
to()
model.to('cuda') # 移动到gpu model.to(torch.float16) # 转换为半精度
cpu()
和 cuda()
model.cuda() # 等价于 model.to('cuda')
forward()
class mymodel(nn.module): def forward(self, x): return self.layer(x)
__call__()
forward()
,支持钩子函数)。output = model(x) # 等价于 output = model.forward(x)
zero_grad()
optimizer.zero_grad() # 等价于 model.zero_grad()
requires_grad_()
for param in model.parameters(): param.requires_grad = false # 冻结所有参数
extra_repr()
class mymodel(nn.module): def extra_repr(self): return f"hidden_size={self.hidden_size}"
dump_patches()
apply()
def init_weights(m): if isinstance(m, nn.conv2d): nn.init.kaiming_normal_(m.weight) model.apply(init_weights)
register_forward_hook()
日常使用中,最频繁的方法包括:
parameters()
, children()
, modules()
train()
, eval()
, zero_grad()
, forward()
state_dict()
, load_state_dict()
to()
, cuda()
, cpu()
其他方法根据具体需求选择使用,例如钩子函数用于高级调试,apply()
用于统一初始化。
nn.module
_parameters
, _modules
, _buffers
等。nn.sequential
nn.module
的子类,继承了所有基础功能。__getitem__
、append
)。功能类别 | nn.module | nn.sequential |
---|---|---|
模块构建 | 需要手动实现 forward 方法 | 自动按顺序执行子模块,无需定义 forward |
子模块访问 | 通过属性名(如 self.conv1 ) | 通过索引或命名(如 model[0] ) |
动态修改 | 需手动管理子模块 | 支持 append 、extend 、insert 等操作 |
适用场景 | 复杂网络结构(如resnet、u-net) | 简单顺序结构(如lenet卷积部分) |
# 模型参数与结构 ['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules'] # 模型状态 ['train', 'eval', 'training', 'zero_grad', 'requires_grad_'] # 设备与数据类型 ['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16'] # 保存与加载 ['state_dict', 'load_state_dict'] # 钩子机制 ['register_forward_hook', 'register_backward_hook']
# 列表操作(动态修改模块顺序) ['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop'] # 索引相关 ['_get_item_by_idx']
# 自定义实现 ['forward', 'extra_repr'] # 高级管理 ['add_module', 'register_module', 'register_parameter', 'register_buffer']
# nn.module(需自定义 forward) class custommodel(nn.module): def __init__(self): super().__init__() self.conv = nn.conv2d(3, 64, 3) self.relu = nn.relu() def forward(self, x): return self.relu(self.conv(x)) # nn.sequential(自动按顺序执行) seq_model = nn.sequential( nn.conv2d(3, 64, 3), nn.relu() )
# nn.module custom_model.conv # 通过属性名访问 # nn.sequential seq_model[0] # 通过索引访问 seq_model.append(nn.maxpool2d(2)) # 动态添加模块
特性 | nn.module | nn.sequential |
---|---|---|
灵活性 | 高(自定义任意逻辑) | 低(仅支持顺序执行) |
代码复杂度 | 较高(需手动实现 forward ) | 低(自动处理前向传播) |
动态修改 | 不支持直接操作(需手动管理) | 支持 append 、insert 等操作 |
适用场景 | 复杂网络、分支结构、自定义操作 | 简单堆叠模块(如cnn的卷积部分) |
建议:
nn.sequential
以减少代码量。nn.module
自定义实现。到此这篇关于pytorch中nn.module详解的文章就介绍到这了,更多相关pytorch nn.module内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!
您想发表意见!!点此发布评论
版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。
发表评论