2026-01-19 • Pycharm
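Deep learning, an important branch of machine learning, has achieved remarkable results in computer vision, natural language processing and other fields. PyTorch is a deep learning framework open-sourced by Facebook. It is widely popular for its dynamic computation graphs and intuitive API design. This article uses the classic MNIST handwritten digit dataset to show how to build and train a deep learning model with PyTorch.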
First, check the versions of PyTorch and related libraries to confirm that the environment is set up correctly:
import torch
import torchvision
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt

print(torch.__version__)
print(torchaudio.__version__)
print(torchvision.__version__)

The MNIST dataset contains 60,000 training samples and 10,000 test samples. Each sample is a 28×28-pixel grayscale image of a handwritten digit.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
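Parameters:
root: directory where the data is stored
train: whether to load the training set (True) or the test set (False)
download: whether to download the data automatically
transform: preprocessing applied to each sample; ToTensor() converts a PIL image to a tensor and scales pixel values to [0, 1]

We can print the dataset size and display a few samples: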
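print(len(training_data))  # 60000
figure = plt.figure()
for i in range(9):
    # Show 9 images from near the end of the training set (indices 59000 to 59008)
    img, label = training_data[i + 59000]
    figure.add_subplot(3, 3, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")  # squeeze (1, 28, 28) -> (28, 28)
plt.show()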
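The transform=ToTensor() argument turns each 28×28 grayscale image into the (1, 28, 28) float tensor that img.squeeze() reduces back to 28×28 for plotting. The sketch below approximates that conversion without any framework: it puts the channel axis first and divides the 8-bit pixel values by 255. The helper `to_unit_tensor` and the synthetic image are hypothetical and are not torchvision code.

```python
def to_unit_tensor(image):
    """Mimic ToTensor() for one grayscale image (hypothetical helper).

    image: list of H rows, each a list of W integers in 0..255.
    Returns a nested list of shape [1][H][W] (channel first) with floats in [0, 1].
    """
    return [[[pixel / 255.0 for pixel in row] for row in image]]


# Synthetic 28x28 "image": a bright vertical stroke on a black background.
image = [[255 if 12 <= col <= 15 else 0 for col in range(28)] for _ in range(28)]
tensor = to_unit_tensor(image)

shape = (len(tensor), len(tensor[0]), len(tensor[0][0]))
print(shape)  # (1, 28, 28)
print(min(min(r) for r in tensor[0]), max(max(r) for r in tensor[0]))  # 0.0 1.0
```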


Use DataLoader to load the data in batches. With shuffle=True, the training set is reshuffled at every epoch:
# Larger batch size
train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128)
for x, y in test_dataloader:
    print(f"Shape of x [N, C, H, W]: {x.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
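Conceptually, one pass over a DataLoader takes the sample indices in order, or in a fresh random order when shuffle=True, and cuts them into chunks of batch_size. With the default drop_last=False, the final and smaller chunk is kept. The pure-Python sketch below is a simplified illustration of that behaviour, not DataLoader's actual implementation; `iterate_minibatches` is a hypothetical helper.

```python
import math
import random


def iterate_minibatches(num_samples, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, one list per batch (simplified sketch)."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new random order for this pass
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]  # the last batch may be smaller


batches = list(iterate_minibatches(60_000, 128, shuffle=True, seed=0))
print(len(batches), math.ceil(60_000 / 128))  # 469 469
print(len(batches[-1]))  # 96, because 60000 - 468 * 128 = 96
```

For the 10,000-image test set, the same arithmetic gives 79 batches, and the last one holds 16 images.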

Select the compute device based on the available hardware:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"using {device} device")

Design a deep neural network made of several fully connected layers:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # (N, 1, 28, 28) -> (N, 784)
        # Original architecture
        self.hidden1 = nn.Linear(28 * 28, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)  # one output (logit) per digit class

    def forward(self, x):
        # Original forward pass
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.sigmoid(x)
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        x = self.out(x)  # raw logits; CrossEntropyLoss applies softmax internally
        return x


model = NeuralNetwork().to(device)
print(model)
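The sketch below relates the layer sizes to model size by counting weights and biases for a chain of fully connected layers. Each nn.Linear(n_in, n_out) holds n_in × n_out weights plus n_out biases, and nn.Flatten has no parameters. The helper name is hypothetical. For the model above, sum(p.numel() for p in model.parameters()) should give the same total.

```python
def count_linear_params(layer_sizes):
    """Count weights + biases in a stack of nn.Linear-style layers (hypothetical helper).

    layer_sizes: e.g. [784, 128, 256, 10] means
    Linear(784, 128), Linear(128, 256), Linear(256, 10).
    """
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


original = [28 * 28, 128, 256, 10]  # flatten -> hidden1 -> hidden2 -> out
print(count_linear_params(original))  # 136074
```

Applied to the improved architecture at the end of this article ([784, 512, 256, 128, 10]), the same helper gives 567,434 parameters, roughly four times as many.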

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1  # counts the batches processed so far
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)  # calling the module runs forward()
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [batch: {batch_size_num}]")
        batch_size_num += 1
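Each iteration of train() follows the same cycle: clear old gradients, run the forward pass, compute the loss, backpropagate, then update the parameters. The toy below writes that cycle out by hand for a single weight w fitted to y = 3x. It uses the plain gradient-descent rule w ← w − lr · grad, which is what torch.optim.SGD does with its default momentum=0 and weight_decay=0. The data, learning rate and function name are hypothetical. PyTorch computes the gradient automatically, whereas this sketch writes the derivative out.

```python
def sgd_fit(xs, ys, lr=0.01, epochs=50):
    """Fit y = w * x with per-sample gradient descent (hypothetical toy problem)."""
    w = 0.0
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            grad = 0.0              # optimizer.zero_grad(): start from a clean gradient
            pred = w * x            # forward pass
            error = pred - y        # loss = error ** 2 (squared error)
            grad += 2 * error * x   # loss.backward(): d(loss)/dw accumulates into grad
            w -= lr * grad          # optimizer.step(): w <- w - lr * grad
    return w


w = sgd_fit([1.0, 2.0, 3.0], [3.0, 6.0, 9.0])
print(round(w, 4))  # 3.0
```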
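Training steps:
model.train(): switches to training mode (enables dropout)
optimizer.zero_grad(): clears the gradients from the previous batch, since backward() accumulates them
loss.backward(): backpropagation, computes the gradients
optimizer.step(): updates the model parameters using those gradients

def test(dataloader, model, loss_fn):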
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()  # accumulate the per-batch mean loss
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches  # average loss per batch
    correct /= size  # fraction of correctly classified samples
    print(f"Test result:\n Accuracy: {(100 * correct):.2f}%, Avg loss: {test_loss:>8f}")
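In test(), accuracy is the number of rows whose argmax matches the label, divided by the number of samples. The reported loss is the sum of the per-batch mean losses, divided by num_batches. Below is a small pure-Python sketch of both calculations, using made-up logits, labels and batch losses.

```python
def argmax(row):
    """Index of the largest value in one row, like pred.argmax(1) does per sample."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits, labels):
    """Fraction of samples whose predicted class equals the true label."""
    correct = sum(argmax(row) == label for row, label in zip(logits, labels))
    return correct / len(labels)


# Hypothetical logits for 4 samples over 3 classes, and their true labels.
logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 3.0], [1.5, 2.5, 0.0], [0.0, -0.5, 0.1]]
labels = [0, 2, 0, 2]
print([argmax(row) for row in logits])  # [0, 2, 1, 2]
print(accuracy(logits, labels))         # 0.75

# Hypothetical per-batch mean losses: sum them, then divide by the number of batches.
batch_losses = [0.30, 0.20, 0.10]
avg_loss = sum(batch_losses) / len(batch_losses)
print(round(avg_loss, 4))               # 0.2
```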
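Key points for testing:
model.eval(): switches to evaluation mode (disables dropout)
torch.no_grad(): disables gradient tracking, which saves memory and computation
pred.argmax(1): takes the index of the largest output in each row as the predicted class

loss_fn = nn.CrossEntropyLoss()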
About the loss function:
CrossEntropyLoss is suited to multi-class classification problems.

# Original optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train(train_dataloader, model, loss_fn, optimizer)
test(train_dataloader, model, loss_fn)  # accuracy on the training set after one epoch
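nn.CrossEntropyLoss expects raw, unnormalized scores (logits). It applies log-softmax and the negative log-likelihood internally. By default (reduction="mean") it averages over the batch, which is why the model's last layer has no softmax. For a single sample the loss is logsumexp(z) − z[target]. The pure-Python sketch below computes it with the usual max-subtraction trick; the helper name is hypothetical. A model that outputs near-equal scores for all 10 digits therefore starts at a loss of about ln 10 ≈ 2.30.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample from raw logits (hypothetical helper).

    loss = -log(softmax(logits)[target]) = logsumexp(logits) - logits[target]
    """
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


print(round(cross_entropy([0.0] * 10, 3), 4))     # 2.3026, i.e. ln(10): a uniform guess
print(cross_entropy([10.0, 0.0, 0.0], 0))         # close to 0: confident and correct
print(cross_entropy([10.0, 0.0, 0.0], 1))         # about 10: confident but wrong
```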

epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n----------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

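class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Improved architecture
        self.hidden1 = nn.Linear(28 * 28, 512)  # more neurons
        self.dropout1 = nn.Dropout(0.2)  # add dropout
        self.hidden2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.2)  # add dropout
        self.hidden3 = nn.Linear(256, 128)  # one extra layer
        self.out = nn.Linear(128, 10)

    def forward(self, x):
        # Improved forward pass
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.relu(x)  # ReLU instead of sigmoid
        x = self.dropout1(x)  # randomly zeroes activations, during training only
        x = self.hidden2(x)
        x = torch.relu(x)  # ReLU instead of sigmoid
        x = self.dropout2(x)  # randomly zeroes activations, during training only
        x = self.hidden3(x)
        x = torch.relu(x)
        x = self.out(x)
        return x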
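nn.Dropout(0.2) uses inverted dropout. In training mode, each activation is zeroed with probability p = 0.2 and the survivors are multiplied by 1 / (1 − p) = 1.25, so the expected activation stays the same. After model.eval(), the layer passes its input through untouched. The pure-Python sketch below applies that rule; the `dropout` helper and the all-ones input are hypothetical.

```python
import random


def dropout(values, p, training, rng):
    """Inverted dropout on a list of activations (hypothetical helper).

    Training: zero each element with probability p and scale survivors by 1 / (1 - p).
    Evaluation: return the input unchanged.
    """
    if not training:
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


rng = random.Random(0)
activations = [1.0] * 10_000
train_out = dropout(activations, p=0.2, training=True, rng=rng)
eval_out = dropout(activations, p=0.2, training=False, rng=rng)

print(sum(x == 0.0 for x in train_out) / len(train_out))  # close to 0.2
print(sum(train_out) / len(train_out))                    # close to 1.0
print(eval_out == activations)                            # True
```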
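model = NeuralNetwork().to(device)  # rebuild the model with the improved architecture
# Improved optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # lower learning rate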
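Unlike plain SGD, Adam keeps a running mean of each parameter's gradient (m) and of its squared gradient (v). It corrects both for their zero initialisation and then steps by lr · m̂ / (√v̂ + eps), so the effective step size adapts per parameter. The scalar sketch below follows that update rule with torch.optim.Adam's default betas=(0.9, 0.999) and eps=1e-8, ignoring options such as weight_decay and amsgrad. The quadratic objective, lr=0.1 and the step count are hypothetical choices for the demo, not the settings used for MNIST.

```python
import math


def adam_minimize(grad_fn, w, lr=0.001, betas=(0.9, 0.999), eps=1e-8, steps=1000):
    """Minimize a 1-D function with the Adam update rule (hypothetical helper)."""
    beta1, beta2 = betas
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g        # running mean of gradients (momentum)
        v = beta2 * v + (1 - beta2) * g * g    # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias correction for the zero start
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)  # adaptive step
    return w


# Hypothetical objective f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = adam_minimize(lambda w: 2 * (w - 3), w=0.0, lr=0.1, steps=500)
print(round(w, 2))
```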
