PyTorch 神经网络 —— MNIST 手写数字识别
目标
- 构建完整的训练/验证/测试 Pipeline
- 掌握 nn.Module、DataLoader、optimizer 三大组件
- 理解训练循环(forward → loss → backward → step)
- 使用 GPU 加速训练
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
# ============================================================
# 0. Configuration
# ============================================================
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE, EPOCHS = 64, 5  # mini-batch size, number of epochs
LR = 1e-3  # learning rate
print(f"使用设备: {DEVICE}")
# ============================================================
# 1. Data loading and preprocessing
# ============================================================
transform = transforms.Compose(
    [
        # 0-255 -> 0-1, HWC -> CHW
        transforms.ToTensor(),
        # Mean and standard deviation of MNIST
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Both splits share the same root directory and preprocessing pipeline.
train_dataset = datasets.MNIST(
    "./data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    "./data", train=False, transform=transform, download=True
)

# Only the training split is shuffled.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

print(f"训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}")
# ============================================================
# 2. 定义模型
# ============================================================
class CNN(nn.Module):
    """A small convolutional network for single-channel 28x28 images.

    Three ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d`` stages feed a
    two-layer fully connected head with dropout between the two layers.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Input: (1, 28, 28). Shape comments are per sample, (C, H, W).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # -> (16, 28, 28)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # -> (32, 14, 14)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # -> (64, 7, 7)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)  # halves H and W on every call
        self.dropout = nn.Dropout(0.3)
        # 28 -> 14 -> 7 -> 3 after three poolings, hence 64 * 3 * 3 features.
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) class logits.

        Args:
            x: Image batch of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)  # (16, 14, 14)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)  # (32, 7, 7)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)  # (64, 3, 3)
        # flatten() copies when needed, so unlike view() it also accepts
        # non-contiguous activations (e.g. channels_last memory format).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
# Instantiate the network and move it to the selected device.
model = CNN(num_classes=10)
model.to(DEVICE)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

# ============================================================
# 3. Loss function and optimizer
# ============================================================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)
# Multiply the learning rate by 0.5 every 2 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=2, gamma=0.5
)
# ============================================================
# 4. 训练与评估函数
# ============================================================
def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            Its result is weighted by the batch size, i.e. it is treated
            as a per-batch mean.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once
            per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages
            would otherwise be a division by zero.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(loader, desc="训练", leave=False):
        images, labels = images.to(device), labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear stale gradients, backpropagate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics: re-weight the batch loss by the batch size so the
        # final average is per sample even if the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError(
            "loader yielded no samples; cannot compute average loss/accuracy"
        )
    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    Runs with gradient tracking disabled and leaves ``model`` in
    evaluation mode.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            Its result is weighted by the batch size, i.e. it is treated
            as a per-batch mean.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages
            would otherwise be a division by zero.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Re-weight the batch loss by the batch size so the final average
        # is per sample even if the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError(
            "loader yielded no samples; cannot compute average loss/accuracy"
        )
    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy
# ============================================================
# 5. Training loop
# ============================================================
# Per-epoch metrics; each list gains one entry per completed epoch.
history = {
    key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")
}

for epoch in range(1, EPOCHS + 1):
    print(f"\n{'='*40}\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE
    )
    test_loss, test_acc = evaluate(model, test_loader, criterion, DEVICE)
    scheduler.step()  # advance the learning-rate schedule once per epoch

    # Record this epoch's metrics under their matching history keys.
    for _metric, _value in (
        ("train_loss", train_loss),
        ("train_acc", train_acc),
        ("test_loss", test_loss),
        ("test_acc", test_acc),
    ):
        history[_metric].append(_value)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2%}")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2%}")

print(f"\n✅ 训练完成!最终测试准确率: {test_acc:.2%}")
# ============================================================
# 6. Save the model
# ============================================================
# Store everything needed to resume training, not just the weights.
torch.save(
    {
        "epoch": EPOCHS,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # Without the scheduler state a resumed run cannot restore the
        # position of the StepLR schedule.
        "scheduler_state_dict": scheduler.state_dict(),
        "test_acc": test_acc,
    },
    "mnist_cnn.pth",
)
print("模型已保存为 mnist_cnn.pth")
# ============================================================
# 7. Predict a single sample
# ============================================================
model.eval()
sample, label = test_dataset[0]

# Add a leading batch dimension and move the sample to the model's device.
batch = sample.unsqueeze(0).to(DEVICE)
with torch.no_grad():
    output = model(batch)
    prob = output.softmax(dim=1)  # logits -> class probabilities
    pred = prob.argmax(dim=1).item()

print(f"\n实际数字: {label}")
print(f"预测数字: {pred}")
print(f"各类概率: {prob.cpu().numpy().round(4)}")
预期输出(Epoch 5)
训练集大小: 60000, 测试集大小: 10000
模型参数量: 98,666
Epoch 5/5
Train Loss: 0.0123 | Train Acc: 99.52%
Test Loss: 0.0214 | Test Acc: 99.31%
✅ 训练完成!最终测试准确率: 99.31%
训练 Pipeline 图解
for epoch in range(EPOCHS):
    for batch in DataLoader:
        images, labels → to(device)
        # ① 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        # ② 反向传播
        optimizer.zero_grad()  # 清空旧梯度
        loss.backward()        # 计算新梯度
        optimizer.step()       # 更新参数
    # ③ 学习率调整(每个 epoch 一次)
    scheduler.step()
    # ④ 验证(torch.no_grad())
    evaluate(model, test_loader)
关键要点
| 概念 | 说明 |
| --- | --- |
| nn.Module | 所有神经网络层的基类,定义 forward() |
| DataLoader | 自动批处理、打乱、多进程加载(num_workers) |
| transforms.Compose | 数据预处理流水线 |
| optimizer.zero_grad() | 必须! 否则梯度会累积 |
| model.train() / model.eval() | 切换 Dropout/BN 行为 |
| torch.no_grad() | 推理时禁用梯度计算,节省内存 |
| state_dict | 模型的参数字典,用于保存和加载 |
| torch.save(obj, path) | 通用序列化保存(模型/字典/任意对象) |