"""Train, evaluate and visualise a small CNN classifier on CIFAR-10."""
import argparse
import os
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Configure matplotlib so the Chinese plot titles used below can be rendered.
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei provides CJK glyphs
plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign displayable

# Per-channel statistics passed to Normalize in get_dataloaders; kept at
# module level so show_sample_predictions can invert the transform.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)


class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier.

    The first linear layer expects ``128 * 8 * 8`` features, which matches
    3-channel 32x32 inputs: the padded 3x3 convolutions keep the spatial
    size and each 2x2 pooling halves it (32 -> 16 -> 8).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        return self.classifier(x)


def get_dataloaders(data_path, batch_size=64) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-10 train and test loaders.

    Args:
        data_path: Root directory for the dataset; it is downloaded there
            if missing (``download=True``).
        batch_size: Mini-batch size used by both loaders.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_path, train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(
        root=data_path, train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader


def test_cnn(model, testloader, device) -> float:
    """Evaluate ``model`` on ``testloader`` and return the accuracy in percent.

    The model is switched to eval mode and gradients are disabled. The
    accuracy is printed as well as returned. An empty loader yields ``0.0``
    instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # index of the top logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else 0.0
    print(f'测试准确率: {accuracy:.2f}%')
    return accuracy


def show_sample_predictions(model, testloader, device, num_samples=6) -> None:
    """Save a row of test images titled with true and predicted labels.

    The first batch of ``testloader`` is classified and up to ``num_samples``
    images from it are written to ``results/cnn/sample_predictions.png``.

    Args:
        model: Classifier to run; it is put into eval mode.
        testloader: Loader yielding ``(images, labels)`` batches.
        device: Device the model lives on.
        num_samples: Number of images to show; capped at the batch size.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be a positive integer')

    model.eval()
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():  # inference only, no autograd graph needed
        outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # Move everything back to the CPU for plotting.
    images = images.cpu()
    labels = labels.cpu()
    predicted = predicted.cpu()

    # Invert the Normalize transform: ToPILImage expects floats in [0, 1],
    # while the normalised tensors lie in [-1, 1].
    mean = torch.tensor(_NORM_MEAN).view(1, -1, 1, 1)
    std = torch.tensor(_NORM_STD).view(1, -1, 1, 1)
    images = (images * std + mean).clamp(0.0, 1.0)

    # Never index past the end of the batch.
    count = min(num_samples, images.size(0))

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(1, count, figsize=(15, 3), squeeze=False)
    to_pil = transforms.ToPILImage()
    for i in range(count):
        ax = axes[0][i]
        ax.imshow(to_pil(images[i]))
        ax.set_title(f'实际: {labels[i].item()}\n预测: {predicted[i].item()}')
        ax.axis('off')

    plt.tight_layout()

    # Make sure the output directory exists before saving.
    os.makedirs('results/cnn', exist_ok=True)
    plt.savefig('results/cnn/sample_predictions.png')
    plt.close(fig)


def train_cnn() -> None:
    """Train SimpleCNN on CIFAR-10, then evaluate it and store the results.

    Command-line options: ``--data_path`` (default ``./data``), ``--epochs``
    (default 10), ``--batch_size`` (default 64) and ``--lr`` (default 0.001).
    Side effects: writes ``cnn_model_epoch{N}.pth`` checkpoints to the
    working directory, ``results/cnn/sample_predictions.png`` and
    ``results/cnn_accuracy.txt``.
    """
    parser = argparse.ArgumentParser(description='Train CNN')
    parser.add_argument('--data_path', type=str, default='./data',
                        help='Dataset path')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Mini-batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Adam learning rate')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, testloader = get_dataloaders(args.data_path, args.batch_size)

    model = SimpleCNN().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass, backward pass, parameter update.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the mean loss of every 100 mini-batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[Epoch {epoch + 1}, Batch {i + 1}] '
                      f'loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        # Checkpoint after each epoch; the file name carries the zero-based
        # epoch index.
        torch.save(model.state_dict(), f'cnn_model_epoch{epoch}.pth')

    print('开始测试模型...')
    accuracy = test_cnn(model, testloader, device)

    show_sample_predictions(model, testloader, device)

    # Create the directory here rather than relying on the plotting helper
    # having created it as a side effect.
    os.makedirs('results', exist_ok=True)
    with open('results/cnn_accuracy.txt', 'w', encoding='utf-8') as f:
        f.write(f'Test Accuracy: {accuracy:.2f}%')

    print('Finished Training')


if __name__ == '__main__':
    train_cnn()