logo

使用PyTorch实现CIFAR-10图像分类:完整代码与深度解析

作者:c4t2025.10.12 00:42浏览量:162

简介:本文提供基于PyTorch的CIFAR-10图像分类完整实现,包含数据加载、模型构建、训练与评估全流程代码,每行代码均附详细注释,适合初学者快速掌握深度学习图像分类技术。

使用PyTorch实现CIFAR-10图像分类:完整代码与深度解析

一、引言

图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,以其动态计算图和简洁API深受研究者青睐。本文将通过CIFAR-10数据集(包含10类32x32彩色图像)的分类任务,系统演示如何使用PyTorch实现完整的图像分类流程。内容涵盖数据预处理、模型构建、训练循环、评估指标等关键环节,所有代码均附详细注释。

二、环境准备

首先需要安装必要的Python库:

  1. pip install torch torchvision matplotlib numpy
  • torch: PyTorch核心库
  • torchvision: 提供计算机视觉相关工具和数据集
  • matplotlib: 用于可视化训练过程
  • numpy: 数值计算基础库

三、完整代码实现

1. 数据加载与预处理

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. # 定义数据预处理流程
  5. transform = transforms.Compose([
  6. transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
  7. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
  8. ])
  9. # 加载训练集
  10. trainset = torchvision.datasets.CIFAR10(
  11. root='./data',
  12. train=True,
  13. download=True,
  14. transform=transform
  15. )
  16. trainloader = torch.utils.data.DataLoader(
  17. trainset,
  18. batch_size=32, # 每批加载32个样本
  19. shuffle=True, # 打乱数据顺序
  20. num_workers=2 # 使用2个子进程加载数据
  21. )
  22. # 加载测试集
  23. testset = torchvision.datasets.CIFAR10(
  24. root='./data',
  25. train=False,
  26. download=True,
  27. transform=transform
  28. )
  29. testloader = torch.utils.data.DataLoader(
  30. testset,
  31. batch_size=32,
  32. shuffle=False,
  33. num_workers=2
  34. )
  35. # 类别名称
  36. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  37. 'dog', 'frog', 'horse', 'ship', 'truck')

关键点说明

  • transforms.Compose将多个预处理操作组合
  • Normalize使用均值(0.5,0.5,0.5)和标准差(0.5,0.5,0.5)将像素值从[0,1]映射到[-1,1]
  • DataLoadershuffle=True确保训练时每个epoch数据顺序不同

2. 模型定义

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class CNN(nn.Module):
  4. def __init__(self):
  5. super(CNN, self).__init__()
  6. # 卷积层1:输入通道3(RGB),输出通道32,3x3卷积核
  7. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  8. # 卷积层2:输入通道32,输出通道64,3x3卷积核
  9. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  10. # 最大池化层:2x2窗口,步长2
  11. self.pool = nn.MaxPool2d(2, 2)
  12. # 全连接层1:输入64*8*8(经过两次池化后尺寸),输出128
  13. self.fc1 = nn.Linear(64 * 8 * 8, 128)
  14. # 全连接层2:输入128,输出10(类别数)
  15. self.fc2 = nn.Linear(128, 10)
  16. # Dropout层:防止过拟合,训练时以0.2概率丢弃神经元
  17. self.dropout = nn.Dropout(0.2)
  18. def forward(self, x):
  19. # 第一次卷积+ReLU激活+池化
  20. x = self.pool(F.relu(self.conv1(x))) # [batch,32,16,16]
  21. # 第二次卷积+ReLU激活+池化
  22. x = self.pool(F.relu(self.conv2(x))) # [batch,64,8,8]
  23. # 展平特征图
  24. x = x.view(-1, 64 * 8 * 8) # [batch,4096]
  25. # 全连接层+ReLU激活+Dropout
  26. x = self.dropout(F.relu(self.fc1(x))) # [batch,128]
  27. # 输出层
  28. x = self.fc2(x) # [batch,10]
  29. return x
  30. # 初始化模型
  31. model = CNN()

模型架构解析

  1. 输入尺寸:3x32x32(CIFAR-10图像尺寸)
  2. 特征提取:
    • 两次卷积(32→64通道)+ReLU激活+2x2最大池化
    • 每次池化后尺寸减半(32→16→8)
  3. 分类器:
    • 展平后接128维全连接层
    • Dropout防止过拟合
    • 最终输出10维logits

3. 训练配置

  1. import torch.optim as optim
  2. # 定义损失函数和优化器
  3. criterion = nn.CrossEntropyLoss() # 交叉熵损失
  4. optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
  5. # 设备配置(使用GPU如果可用)
  6. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  7. model.to(device)

优化器选择

  • Adam结合了动量法和RMSProp的优点,适合大多数场景
  • 学习率0.001是经验值,可根据训练情况调整

4. 训练循环

  1. def train_model(model, trainloader, criterion, optimizer, epochs=10):
  2. for epoch in range(epochs):
  3. running_loss = 0.0
  4. correct = 0
  5. total = 0
  6. for i, data in enumerate(trainloader, 0):
  7. inputs, labels = data[0].to(device), data[1].to(device)
  8. # 梯度清零
  9. optimizer.zero_grad()
  10. # 前向传播
  11. outputs = model(inputs)
  12. loss = criterion(outputs, labels)
  13. # 反向传播+优化
  14. loss.backward()
  15. optimizer.step()
  16. # 统计信息
  17. running_loss += loss.item()
  18. _, predicted = torch.max(outputs.data, 1)
  19. total += labels.size(0)
  20. correct += (predicted == labels).sum().item()
  21. # 每200个batch打印一次统计
  22. if i % 200 == 199:
  23. print(f'Epoch {epoch+1}, Batch {i+1}, '
  24. f'Loss: {running_loss/200:.3f}, '
  25. f'Acc: {100*correct/total:.2f}%')
  26. running_loss = 0.0
  27. # 每个epoch结束后打印验证准确率
  28. val_acc = validate_model(model, testloader)
  29. print(f'Epoch {epoch+1} completed. Val Acc: {val_acc:.2f}%')
  30. def validate_model(model, testloader):
  31. correct = 0
  32. total = 0
  33. with torch.no_grad(): # 禁用梯度计算
  34. for data in testloader:
  35. images, labels = data[0].to(device), data[1].to(device)
  36. outputs = model(images)
  37. _, predicted = torch.max(outputs.data, 1)
  38. total += labels.size(0)
  39. correct += (predicted == labels).sum().item()
  40. return 100 * correct / total
  41. # 启动训练
  42. train_model(model, trainloader, criterion, optimizer, epochs=10)

训练技巧

  1. 每个batch前调用optimizer.zero_grad()清除累积梯度
  2. 使用torch.no_grad()上下文管理器进行验证,节省内存
  3. 定期打印训练损失和准确率,监控训练过程

5. 模型评估与可视化

  1. import matplotlib.pyplot as plt
  2. # 可视化部分测试样本
  3. def imshow(img):
  4. img = img / 2 + 0.5 # 反归一化
  5. npimg = img.numpy()
  6. plt.imshow(np.transpose(npimg, (1, 2, 0)))
  7. plt.show()
  8. # 获取一个batch的测试数据
  9. dataiter = iter(testloader)
  10. images, labels = next(dataiter)
  11. images, labels = images.to(device), labels.to(device)
  12. # 显示图像及预测结果
  13. outputs = model(images)
  14. _, predicted = torch.max(outputs, 1)
  15. imshow(torchvision.utils.make_grid(images[:4]))
  16. print('GroundTruth: ', ' '.join(f'{classes[labels[j]]}' for j in range(4)))
  17. print('Predicted: ', ' '.join(f'{classes[predicted[j]]}' for j in range(4)))
  18. # 最终评估
  19. final_acc = validate_model(model, testloader)
  20. print(f'Final Test Accuracy: {final_acc:.2f}%')

四、进阶优化建议

  1. 数据增强:在transform中添加随机裁剪、水平翻转等操作提升模型泛化能力

    1. transform_train = transforms.Compose([
    2. transforms.RandomHorizontalFlip(),
    3. transforms.RandomCrop(32, padding=4),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    6. ])
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率

    1. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用scheduler.step()
  3. 模型保存:保存最佳模型权重

    1. best_acc = 0.0
    2. for epoch in range(10):
    3. # ...训练代码...
    4. val_acc = validate_model(model, testloader)
    5. if val_acc > best_acc:
    6. best_acc = val_acc
    7. torch.save(model.state_dict(), 'best_model.pth')

五、总结

本文完整实现了基于PyTorch的CIFAR-10图像分类系统,包含:

  1. 数据加载与标准化预处理
  2. 自定义CNN模型构建(含卷积层、池化层、全连接层)
  3. 训练循环实现(含损失计算、反向传播、参数更新)
  4. 模型评估与可视化方法

通过实践,读者可以掌握PyTorch进行图像分类的核心流程,并理解各组件的作用。实际项目中,可进一步尝试更复杂的模型架构(如ResNet)、更大的数据集(如ImageNet)或分布式训练等高级技术。

相关文章推荐

发表评论

活动