科技 > 人工智能 > 神经网络

pytorch写一个神经网络训练示例代码

112人参与 2024-08-05 神经网络

pytorch写一个神经网络训练示例代码

当使用pytorch训练全连接神经网络时,你需要定义神经网络模型、损失函数和优化器,并编写训练循环。下面是一个简单的示例,演示了如何使用pytorch来训练一个全连接神经网络。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

# 定义全连接神经网络模型
class simplenet(nn.module):
    def __init__(self, input_size, hidden_size, output_size):
        super(simplenet, self).__init__()
        self.fc1 = nn.linear(input_size, hidden_size)
        self.fc2 = nn.linear(hidden_size, output_size)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 准备数据
features = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=torch.float)
labels = torch.tensor([0, 1, 0], dtype=torch.long)

# 创建自定义的dataset
custom_dataset = data.tensordataset(features, labels)

# 创建dataloader
dataloader = data.dataloader(custom_dataset, batch_size=2, shuffle=true)

# 初始化模型、损失函数和优化器
input_size = 3
hidden_size = 5
output_size = 2
model = simplenet(input_size, hidden_size, output_size)
criterion = nn.crossentropyloss()
optimizer = optim.sgd(model.parameters(), lr=0.1)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f'epoch [{epoch+1}/{num_epochs}], loss: {loss.item():.4f}')

# 使用训练好的模型进行预测
test_input = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float)
with torch.no_grad():
    output = model(test_input)
    _, predicted = torch.max(output, 1)
    print('predicted class:', predicted.item())

在这个示例中,我们首先定义了一个简单的全连接神经网络模型simplenet,然后准备了数据并创建了dataset和dataloader。接下来,我们初始化了模型、损失函数和优化器,并进行了训练循环。最后,我们使用训练好的模型进行了一个简单的预测。

你可以根据自己的数据和模型结构进行相应的修改,希望这个示例能够帮助你开始使用pytorch训练全连接神经网络。

(0)
打赏 微信扫一扫 微信扫一扫

您想发表意见!!点此发布评论

推荐阅读

使用自己的数据利用pytorch搭建全连接神经网络进行回归预测

08-05

【NLP基础知识五】文本分类之神经网络文本分类、多标签分类

08-05

高斯分布的神经网络应用

08-05

神经网络应用场景——图像识别

08-05

深入探究深度学习、神经网络与卷积神经网络以及它们在多个领域中的应用

08-05

深度学习(6)--Keras项目详解(传统神经网络)

08-05

猜你喜欢

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论