# Listing metadata (from the code viewer): 163 lines, 5.2 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import argparse
|
|
import os
|
|
|
|
# Configure matplotlib so Chinese text renders in figures: use the SimHei font
# for sans-serif text and keep the minus sign displaying correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
|
# CNN image-classification model definition
|
|
class SimpleCNN(nn.Module):
    """Small convolutional classifier for 3-channel images.

    The feature extractor applies two Conv -> ReLU -> MaxPool stages
    (3 -> 64 -> 128 channels); each 2x2 pooling halves the spatial size.
    The classifier head takes ``128 * 8 * 8`` flattened features, which
    corresponds to 32x32 inputs.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        conv_stages = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stages)

        head_layers = [
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension before the head.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)
|
|
|
|
# Data loading and preprocessing
|
|
def get_dataloaders(data_path, batch_size=64):
    """Build the CIFAR-10 training and test data loaders.

    Every image is converted to a tensor and normalized per channel with
    mean 0.5 and standard deviation 0.5. Both splits are requested with
    ``download=True`` and stored under ``data_path``.

    Args:
        data_path: Root directory for the CIFAR-10 files.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple; only the training loader
        shuffles its samples.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def build_loader(is_train):
        # The training split is shuffled, the test split keeps its order.
        split = torchvision.datasets.CIFAR10(
            root=data_path,
            train=is_train,
            download=True,
            transform=preprocessing,
        )
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    train_batches = build_loader(True)
    test_batches = build_loader(False)
    return train_batches, test_batches
|
|
|
|
# Evaluation function
|
|
def test_cnn(model, testloader, device):
|
|
model.eval() # 设置为评估模式
|
|
correct = 0
|
|
total = 0
|
|
with torch.no_grad(): # 不需要计算梯度
|
|
for data in testloader:
|
|
images, labels = data
|
|
images, labels = images.to(device), labels.to(device)
|
|
outputs = model(images)
|
|
_, predicted = torch.max(outputs.data, 1) # 获取预测结果
|
|
total += labels.size(0)
|
|
correct += (predicted == labels).sum().item() # 计算正确预测数
|
|
|
|
accuracy = 100 * correct / total # 计算准确率
|
|
print(f'测试准确率: {accuracy:.2f}%') # 打印准确率
|
|
return accuracy
|
|
|
|
# Visualize sample predictions
|
|
def show_sample_predictions(model, testloader, device, num_samples=6,
                            mean=0.5, std=0.5):
    """Save a figure showing predictions for the first test batch.

    Runs ``model`` on the first batch of ``testloader`` and writes a row of
    images, each titled with its actual and predicted class index, to
    ``results/cnn/sample_predictions.png``.

    Args:
        model: Module mapping a batch of images to per-class scores.
        testloader: Iterable yielding ``(images, labels)`` batches.
        device: Device used for the forward pass.
        num_samples: Maximum number of images to show; capped at the size
            of the first batch.
        mean: Per-channel mean (scalar or sequence) that was used to
            normalize the images; the defaults match the 0.5 / 0.5
            normalization applied in this file's data loaders.
        std: Per-channel standard deviation used for normalization.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1')

    model.eval()
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        predicted = model(images).argmax(dim=1)

    # Move everything back to the CPU for plotting.
    images = images.cpu()
    labels = labels.cpu()
    predicted = predicted.cpu()

    # Undo the normalization so pixel values are back in [0, 1]; ToPILImage
    # expects that range and produces corrupted colors for negative values.
    mean_t = torch.as_tensor(mean, dtype=images.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype).view(-1, 1, 1)
    images = (images * std_t + mean_t).clamp(0.0, 1.0)

    # Never index past the end of the batch.
    shown = min(num_samples, images.size(0))

    # squeeze=False keeps ``axes`` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, shown, figsize=(15, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = transforms.ToPILImage()(images[i])
        ax.imshow(img)
        ax.set_title('实际: %d\n预测: %d' % (labels[i].item(), predicted[i].item()))
        ax.axis('off')

    plt.tight_layout()
    # Make sure the output directory exists before saving.
    os.makedirs('results/cnn', exist_ok=True)
    plt.savefig('results/cnn/sample_predictions.png')
    plt.close(fig)
|
|
|
|
# Training entry point
|
|
def train_cnn():
    """Train ``SimpleCNN`` on CIFAR-10, then evaluate it and save the results.

    Command-line options (all optional):
        --data_path   Dataset directory (default ``./data``).
        --epochs      Number of passes over the training set (default 10).
        --batch_size  Samples per batch (default 64).
        --lr          Adam learning rate (default 0.001).

    Side effects:
        * Writes ``cnn_model_epoch{N}.pth`` (N is the 0-based epoch index)
          to the current directory after every epoch.
        * Calls ``show_sample_predictions`` to save a prediction figure.
        * Writes the test accuracy to ``results/cnn_accuracy.txt``.
    """
    # Argument parsing
    parser = argparse.ArgumentParser(description='Train CNN')
    parser.add_argument('--data_path', type=str, default='./data', help='Dataset path')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Mini-batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    args = parser.parse_args()

    # Device selection: use the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data loading
    trainloader, testloader = get_dataloaders(args.data_path, args.batch_size)

    # Model initialization
    model = SimpleCNN().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.epochs):
        model.train()  # make sure the model is in training mode
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Reset gradients accumulated by the previous step.
            optimizer.zero_grad()

            # Forward pass, backward pass, parameter update.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the mean loss every 100 mini-batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        # Save a checkpoint for this epoch.
        # NOTE(review): the original indentation was lost; the epoch number in
        # the filename suggests the save belongs inside the epoch loop — confirm.
        torch.save(model.state_dict(), f'cnn_model_epoch{epoch}.pth')

    # Evaluate the trained model.
    print('开始测试模型...')
    accuracy = test_cnn(model, testloader, device)

    # Save a figure with sample predictions.
    show_sample_predictions(model, testloader, device)

    # Persist the test accuracy; create the directory explicitly instead of
    # relying on another function having created it.
    os.makedirs('results', exist_ok=True)
    with open('results/cnn_accuracy.txt', 'w', encoding='utf-8') as f:
        f.write(f'Test Accuracy: {accuracy:.2f}%')

    print('Finished Training')
|
|
|
|
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    train_cnn()