DeepQuantom-CNN/Modify.ipynb


{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:47.268865Z",
"start_time": "2025-06-25T06:49:25.453526Z"
}
},
"cell_type": "code",
"source": [
"# 首先我们导入所有需要的包:\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision.transforms as transforms\n",
"from tqdm import tqdm\n",
"from sklearn.metrics import roc_auc_score\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import FashionMNIST\n",
"import deepquantum as dq\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def seed_torch(seed=1024):\n",
" \"\"\"\n",
" Set random seeds for reproducibility.\n",
"\n",
" Args:\n",
" seed (int): Random seed number to use. Default is 1024.\n",
" \"\"\"\n",
"\n",
" random.seed(seed)\n",
" os.environ['PYTHONHASHSEED'] = str(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed(seed)\n",
"\n",
" # Seed all GPUs with the same seed if using multi-GPU\n",
" torch.cuda.manual_seed_all(seed)\n",
" torch.backends.cudnn.benchmark = True\n",
" torch.backends.cudnn.deterministic = True\n",
"\n",
"seed_torch(42) # 使用更常见的随机种子值"
],
"id": "9cf2e2b5d8a6892d",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:47.819194Z",
"start_time": "2025-06-25T06:49:47.811528Z"
}
},
"cell_type": "code",
"source": [
"def calculate_score(y_true, y_preds):\n",
" # 将模型预测结果转为概率分布\n",
" preds_prob = torch.softmax(y_preds, dim=1)\n",
" # 获得预测的类别(概率最高的一类)\n",
" preds_class = torch.argmax(preds_prob, dim=1)\n",
" # 计算准确率\n",
" correct = (preds_class == y_true).float()\n",
" accuracy = correct.sum() / len(correct)\n",
" return accuracy.cpu().numpy()\n",
"\n",
"\n",
"def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):\n",
" \"\"\"\n",
" 训练和验证模型。\n",
"\n",
" Args:\n",
" model (torch.nn.Module): 要训练的模型。\n",
" criterion (torch.nn.Module): 损失函数。\n",
" optimizer (torch.optim.Optimizer): 优化器。\n",
" train_loader (torch.utils.data.DataLoader): 训练数据加载器。\n",
" valid_loader (torch.utils.data.DataLoader): 验证数据加载器。\n",
" num_epochs (int): 训练的epoch数。\n",
"\n",
" Returns:\n",
" model (torch.nn.Module): 训练后的模型。\n",
" \"\"\"\n",
"\n",
" model.train()\n",
" train_loss_list = []\n",
" valid_loss_list = []\n",
" train_acc_list = []\n",
" valid_acc_list = []\n",
"\n",
" best_valid_acc = 0.0\n",
" patience = 10 # 早停耐心值\n",
" counter = 0 # 计数器\n",
"\n",
" scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)\n",
"\n",
" with tqdm(total=num_epochs) as pbar:\n",
" for epoch in range(num_epochs):\n",
" # 训练阶段\n",
" train_loss = 0.0\n",
" train_acc = 0.0\n",
" for images, labels in train_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" optimizer.zero_grad()\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_acc += calculate_score(labels, outputs)\n",
"\n",
" train_loss /= len(train_loader)\n",
" train_acc /= len(train_loader)\n",
"\n",
" # 验证阶段\n",
" model.eval()\n",
" valid_loss = 0.0\n",
" valid_acc = 0.0\n",
" with torch.no_grad():\n",
" for images, labels in valid_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" valid_loss += loss.item()\n",
" valid_acc += calculate_score(labels, outputs)\n",
"\n",
" valid_loss /= len(valid_loader)\n",
" valid_acc /= len(valid_loader)\n",
"\n",
" # 学习率调度器更新\n",
" scheduler.step(valid_acc)\n",
"\n",
" # 早停机制\n",
" if valid_acc > best_valid_acc:\n",
" best_valid_acc = valid_acc\n",
" torch.save(model.state_dict(), './data/notebook2/best_model.pt')\n",
" counter = 0\n",
" else:\n",
" counter += 1\n",
"\n",
" if counter >= patience:\n",
" print(f'Early stopping at epoch {epoch+1} due to no improvement in validation accuracy.')\n",
" break\n",
"\n",
" pbar.set_description(f\"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}\")\n",
" pbar.update()\n",
"\n",
"\n",
" train_loss_list.append(train_loss)\n",
" valid_loss_list.append(valid_loss)\n",
" train_acc_list.append(train_acc)\n",
" valid_acc_list.append(valid_acc)\n",
"\n",
" # 加载最佳模型权重\n",
" if os.path.exists('./data/notebook2/best_model.pt'):\n",
" model.load_state_dict(torch.load('./data/notebook2/best_model.pt'))\n",
"\n",
" # 修改metrics构建方式确保各数组长度一致\n",
" metrics = {\n",
" 'epoch': list(range(1, len(train_loss_list) + 1)),\n",
" 'train_acc': train_acc_list,\n",
" 'valid_acc': valid_acc_list,\n",
" 'train_loss': train_loss_list,\n",
" 'valid_loss': valid_loss_list\n",
" }\n",
"\n",
"\n",
" return model, metrics\n",
"\n",
"def test_model(model, test_loader, device):\n",
" model.eval()\n",
" test_acc = 0.0\n",
" with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" test_acc += calculate_score(labels, outputs)\n",
"\n",
" test_acc /= len(test_loader)\n",
" print(f'Test Acc: {test_acc:.3f}')\n",
" return test_acc"
],
"id": "3a0bcd81cba9b9d4",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:47.839016Z",
"start_time": "2025-06-25T06:49:47.826008Z"
}
},
"cell_type": "code",
"source": [
"# 定义图像变换\n",
"trans1 = transforms.Compose([\n",
" transforms.RandomHorizontalFlip(), # 随机水平翻转\n",
" transforms.RandomRotation(10), # 随机旋转±10度\n",
" transforms.ColorJitter(brightness=0.2, contrast=0.2), # 颜色调整\n",
" transforms.Resize((18, 18)), # 调整大小为18x18\n",
" transforms.ToTensor(), # 转换为张量\n",
" transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]\n",
"])\n",
"\n",
"trans2 = transforms.Compose([\n",
" transforms.RandomHorizontalFlip(), # 随机水平翻转\n",
" transforms.RandomRotation(10), # 随机旋转±10度\n",
" transforms.ColorJitter(brightness=0.2, contrast=0.2), # 颜色调整\n",
" transforms.Resize((16, 16)), # 调整大小为16x16\n",
" transforms.ToTensor(), # 转换为张量\n",
" transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]\n",
"])\n",
"train_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=trans1,download=True)\n",
"test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=trans1,download=True)\n",
"\n",
"# 定义训练集和测试集的比例\n",
"train_ratio = 0.8 # 训练集比例为80%验证集比例为20%\n",
"valid_ratio = 0.2\n",
"total_samples = len(train_dataset)\n",
"train_size = int(train_ratio * total_samples)\n",
"valid_size = int(valid_ratio * total_samples)\n",
"\n",
"# 分割训练集和测试集\n",
"train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])\n",
"\n",
"# 加载随机抽取的训练数据集\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, drop_last=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)"
],
"id": "7228ef013f3a1fdd",
"outputs": [],
"execution_count": 3
},
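{
"metadata": {},
"cell_type": "code",
"source": [
"# Added sanity-check sketch (not part of the original run): inspect one batch from the\n",
"# training loader. With trans1 resizing to 18x18 and batch_size=64, the expected shapes\n",
"# are [64, 1, 18, 18] for images and [64] for labels.\n",
"images, labels = next(iter(train_loader))\n",
"print(images.shape, labels.shape)"
],
"id": "added_batch_shape_check",
"outputs": [],
"execution_count": null
},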
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:47.845266Z",
"start_time": "2025-06-25T06:49:47.841750Z"
}
},
"cell_type": "code",
"source": [
"singlegate_list = ['rx', 'ry', 'rz', 's', 't', 'p', 'u3']\n",
"doublegate_list = ['rxx', 'ryy', 'rzz', 'swap', 'cnot', 'cp', 'ch', 'cu', 'ct', 'cz']"
],
"id": "d9e483c48405660",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:47.895431Z",
"start_time": "2025-06-25T06:49:47.889872Z"
}
},
"cell_type": "code",
"source": [
"# 随机量子卷积层\n",
"class RandomQuantumConvolutionalLayer(nn.Module):\n",
" def __init__(self, nqubit, num_circuits, seed:int=1024):\n",
" super(RandomQuantumConvolutionalLayer, self).__init__()\n",
" random.seed(seed)\n",
" self.nqubit = nqubit\n",
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"\n",
" def circuit(self, nqubit):\n",
" cir = dq.QubitCircuit(nqubit)\n",
" cir.rxlayer(encode=True) # 对原论文的量子线路结构并无影响,只是做了一个数据编码的操作\n",
" cir.barrier()\n",
" for iter in range(3):\n",
" for i in range(nqubit):\n",
" singlegate = random.choice(singlegate_list)\n",
" getattr(cir, singlegate)(i)\n",
" control_bit, target_bit = random.sample(range(0, nqubit - 1), 2)\n",
" doublegate = random.choice(doublegate_list)\n",
" if doublegate[0] in ['r', 's']:\n",
" getattr(cir, doublegate)([control_bit, target_bit])\n",
" else:\n",
" getattr(cir, doublegate)(control_bit, target_bit)\n",
" cir.barrier()\n",
"\n",
" cir.observable(0)\n",
" return cir\n",
"\n",
" def forward(self, x):\n",
" kernel_size, stride = 2, 2\n",
" # [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]\n",
" x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)\n",
" w = int((x.shape[-1] - kernel_size) / stride + 1)\n",
" x_reshape = x_unflod.reshape(-1, self.nqubit)\n",
"\n",
" exps = []\n",
" for cir in self.cirs: # out_channels\n",
" cir(x_reshape)\n",
" exp = cir.expectation()\n",
" exps.append(exp)\n",
"\n",
" exps = torch.stack(exps, dim=1)\n",
" exps = exps.reshape(x.shape[0], 3, w, w)\n",
" return exps"
],
"id": "f24b62bd70ab89eb",
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:48.381594Z",
"start_time": "2025-06-25T06:49:47.897435Z"
}
},
"cell_type": "code",
"source": [
"net = RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024)\n",
"net.cirs[0].draw()"
],
"id": "f97107a549dee68",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1207.22x367.889 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7UAAAEvCAYAAACaO+Y5AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAZPBJREFUeJzt3XlYVOXbB/DvDAPDsMgmCoKIgoCQW+5b7qampqlZabbZXpaallqZabZopVm9mfbLzCXNPXNPS1xxS3EBEdn3HQaGgVneP8hJAhWQmcOc8/1cl1dxzjMz98C572fuOZvMaDQaQURERERERGSF5EIHQERERERERFRXbGqJiIiIiIjIarGpJSIiIiIiIqvFppaIiIiIiIisFptaIiIiIiIislpsaomIiIiIiMhqsaklIiIiIiIiq8WmloiIiIiIiKwWm1oiIiIiIiKyWmxqiYiIiIiIyGqxqSUiIiIiIiKrxaaWiIiIiIiIrBabWiIiIiIiIrJabGqJiIiIiIjIarGpJSIiIiIiIqvFppaIiIiIiIisFptaIiIiIiIislpsaomIiIiIiMhqsaklIiIiIiIiq8WmloiIiIiIiKwWm1oiIiIiIiKyWmxqiYiIiIiIyGqxqSUiIiIiIiKrxaaWiIiIiIiIrBabWiIiIiIiIrJabGqJiIiIiIjIarGpJSIiIiIiIqvFppaIiIiIiIislkLoAIjuxenTp2s1Pjs7G1u3bsUjjzyCxo0b1+gxXbp0qUtoRGQhtakDdakBAOsANWycC4lI6nMh99SSpGRnZ2PVqlXIzs4WOhQiEgBrABHzgEjqxFgD2NQSERERERGR1WJTS0RERERERFaLTS0RERERERFZLTa1JCnOzs4YOnQonJ2dhQ6FiATAGkDEPCCSOjHWAJnRaDQKHQRRXdX2io910ZCv9EZErANEzAEiknod4J5akhStVoukpCRotVqhQyEiAbAGEDEPiKROjDWATS1JSlxcHMaOHYu4uDihQyEiAbAGEDEPiKROjDVAIXQAVD2j0Qidxnq+PVGolJDJZEKHIXrWtl2IFbd3sgRrzHfmBjVU1phPltJQ8lasf6OG8vsVOza1DZROo8W6gElCh1FjE2PXwtbBXugwRM/atgux4vZOlmCN+c7coIbKGvPJUhpK3or1b9RQfr9ix8OPiYiIiIiIyGqxqSUiIiIiIiKrxcOPSVJCQkIQEREhdBhEJBDWACLmAZHUibEGcE8tERERERERWS02tSQpCQkJePbZZ5GQkCB0KEQkANYAIuYBkdSJsQawqSVJ0Wg0uHTpEjQajdChEJEAWAOImAdEUifGGsCmloiIiIiIiKwWLxQlIl49wjB06/xKy8qLNSi8kYbYzUdw9YfdMOoNAkVHUhT4aD/0XvYajr7xNa5v+rPKeidfT4w7/X+4vvEwjr75jeUDJBIZzgNEDQ/nQiLzY1MrQnHbjyHp4BlAJoPK0xWB4/ui6/yn4dLaBydmrhA6PCIiMjPOA0REJCVsakUo93IcbmwJN/0cvXofRocvRdATA3Hukw3Q5hQKGJ2wvL29MX/+fHh7ewsdChEJQCo1gPMA3YlU8oCIqifGGsBzaiVAp9Ei+9x1yORyNGrRVOhwBOXi4oJhw4bBxcVF6FCISABSrQGcB+hWUs0DIqogxhrAplYinP0rPsSU5hUJHImw8vLy8OuvvyIvL0/oUIhIAFKuAZwH6CYp5wERibMGsKkVIRuVEkp3Zyg9GsE1xA/dFk2BR9tWyDofg6K4dKHDE1RGRgYWL16MjIwMoUMhIgFIpQZwHqA7kUoeEFH1xFgDJHFObXZ2Nj755BNs27YNKSkp8PT0xKOPPoqFCxfi+eefx7p167By5UpMmTJF6FDrRYfp49Fh+vhKyxL2nMLJt1cKFBEREVkS5wEiIpIS0Te1Fy9exJAhQ5CRkQFHR0eEhoYiJSUFX3zxBeLj45GYmAgAaN++vcCR1p9r6w4ibscxyBU2cA32Q9vXR0PV2BW60jLTGLmdAiP3fYYb244i8qutpuW9l74Ke09XHJz4kRChk0QZjUahQyASFc4DRNaHc2HNtBzdC2Evj4Jra1/oNFqk/nURZxetRXFyttChkYBEffhxdnY2RowYgYyMDMyaNQvp6ek4d+4cMjIysGDBAmzduhV///035HI57rvvPqHDrTdF8elIC49EyuG/cfm7nfjjyY/R+P5A9PjkedMYQ5kO4VOXo93UMXALbQEA8BvaBb6DO+PY9G+FCp1E5uYHaBt7u2rX2zgoAQD6Wz5oE9G94zxA1HBwLqw/Ic8MRd//mwZ9aRlOf/ATrqz8Hc36tsPwnR9B1dRN6PBIQKJuaqdOnYqkpCRMnz4dn376KZycnEzr5s6di5CQEOh0OgQFBUGlUgkYqXllnYvBjS3haPVIH3je39q0PDcyDpe+2YE+X70OB2939PjsRZyaswqaDPGcNP5fDg4O6NatGxwcHIQORRLUiZkAANcg32rXu7auWF70zzgic5NqDeA8QLeSah4IhXNh/VC6OaHTnInIvhiLvY/MQ/Sa/bi4dAsOPPERHJq6oePMCUKHaDXEWANE29ReuXIFGzduRJMmTbBgwYIq62UyGTp16gSg6qHHcXFxGDVqFJydneHm5obJkycjJyfHInGby4UvN8Og06PjrMcqLb/41VYYynUYdWAx0o9fRtyOYwJFaBl+fn5Yvnw5/Pz8hA5FEnIib0CdkoWWo3tV+QZVbqtAm2eHwWgwIGn/GYEiJKmRcg3gPEA3STkPhMC5sH40f7ArbJ1UuLpqN4x6g2l5zoVYZJy8Cv9RPSG3Ff2ZlfVCjDVAtE3tunXrYDAYMGnSpNt+C6FUVhzucWtTW1RUhP79+yM5ORkbNmzA999/j/DwcIwYMQIGg6Ha57EGRfHpiNtxDM36tkeTbm1My416AzIiomDv4YLrGw8LGKFl6PV6qNVq6PV6oUORBKPegJNvr4StswMePvQFOs2dhKBJg9Bu2jiM3P8ZvHqGIfLr7SiMTRU6VJIIKdcAzgN0k5TzQAicC+uHZ8dAAEDmmegq6zLPRMPO2QEugT6WDssqibEGiLap/eOPPwAA/fv3v+2Y5ORkAJWb2u+//x4pKSnYvn07RowYgfHjx2PdunU4efIkdu7cad6gzezisi0w6PWVDs9o0iUYrR8fgKs/7EbXD5+57fkeYhETE4MBAwYgJiZG6FAkI/mPc9g96l2kH7uEwEf7ovuiKbjvxZEozSnEny98jnMfrxc6RJIQqdcAzgMEMA+EwLnw3jl4uQMAStJyq6wrTq04otLB292iMVkrMdYA0e6jT0hIAAD4+/tXu16r1eLUqVMAKje1u3btQu/evSvtju/Zsyf8/f3x22+/YfTo0bWOpXPnzkhPr919AW2NcsxD11o9Jv3EZaz2Hnfb9QUxKVjj++8HGYVKid7LXsO5ResR9dM+DNv2ITrNnYiI936s1esCQFDrIJTLLL8ne9y427/f6mRmVpyvsmfPHpw9e7ZGjxkzZkyt4zKXumwXDUHOhVj8+
cLnQodRb4Ta3ql6takDdakBgDB1wNrmAYC5IRSpzYV10RDmz4Y6FzaUvL3b38hG9c8FtbTlVdbdXKb4Z0xDYqnfrxjmQi8vL5w5U7fD8EXb1BYXFwMASktLq12/du1aFBQUwMPDAz4+/x6qcOXKFYwfP77K+NDQUFy5cqVOsaSnpyMlJaVWj7GT2QBN6/RyNdZl/lNQJ2UhavVeAMDRN77GqINLkLg7AuknLtfquVLTUlFmtPwhDDf/zjWl0WhM/63pY2v7tzMnS2wXdHdCbe9UvdrUgbrUAECYOmBt8wDA3BCK1ObCuuD8eXsNJW/v9jfSa7QAABulbZUrRSv+OcJE98+YhsRSv1+xzoU1Jdqm1svLC0VFRThz5gw6d+5caV1SUhJmz54NAGjXrl2ldXl5eXB1da3yfG5ubrh+/XqdY6ktW6McMOOXOj79O6DlqF7YMXCGaVlRQgbOfrQOvZa+gh0DZkBXXP0XAtVp5t1MkG/5HB0dazX+ZuKqVKoaP/bWLz2EZu7tgmpGqO2dqle
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
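{
"metadata": {},
"cell_type": "code",
"source": [
"# Added sketch (not part of the original run): a quick shape check of the random quantum\n",
"# convolutional layer instantiated above as net. With kernel_size = stride = 2 and\n",
"# num_circuits = 3, an 18x18 single-channel batch should map to [batch, 3, 9, 9].\n",
"dummy = torch.rand(8, 1, 18, 18)\n",
"print(net(dummy).shape)"
],
"id": "added_random_layer_shape_check",
"outputs": [],
"execution_count": null
},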
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:49:48.391855Z",
"start_time": "2025-06-25T06:49:48.387742Z"
}
},
"cell_type": "code",
"source": [
"# 基于随机量子卷积层的混合模型\n",
"class RandomQCCNN(nn.Module):\n",
" def __init__(self):\n",
" super(RandomQCCNN, self).__init__()\n",
" self.conv = nn.Sequential(\n",
" RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024), # num_circuits=3代表我们在quanv1层只用了3个量子卷积核\n",
" nn.BatchNorm2d(3), # 添加批量归一化\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1),\n",
" nn.Conv2d(3, 6, kernel_size=2, stride=1),\n",
" nn.BatchNorm2d(6), # 添加批量归一化\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1)\n",
" )\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(6 * 6 * 6, 1024),\n",
" nn.BatchNorm1d(1024), # 添加批量归一化\n",
" nn.Dropout(0.5), # 增加dropout比例\n",
" nn.ReLU(),\n",
" nn.Linear(1024, 10)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = x.reshape(x.size(0), -1)\n",
" x = self.fc(x)\n",
" return x"
],
"id": "82be55cb39abc20a",
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:06.782447Z",
"start_time": "2025-06-25T06:49:48.398218Z"
}
},
"cell_type": "code",
"source": [
"# 修改RandomQCCNN模型的训练参数\n",
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"seed_torch(42) # 使用相同的随机种子值\n",
"model = RandomQCCNN()\n",
"model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5) # 使用AdamW优化器和适当的权重衰减\n",
"optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(optim_model.state_dict(), './data/notebook2/random_qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook2/random_qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "2087a1b2f259ad1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 0.556 Valid Acc: 0.760: 11%|█ | 33/300 [04:18<34:49, 7.83s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Early stopping at epoch 34 due to no improvement in validation accuracy.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:11.647094Z",
"start_time": "2025-06-25T06:54:06.811569Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook2/random_qccnn_weights.pt', map_location=device)\n",
"random_qccnn_model = RandomQCCNN()\n",
"random_qccnn_model.load_state_dict(state_dict)\n",
"random_qccnn_model.to(device)\n",
"\n",
"test_acc = test_model(random_qccnn_model, test_loader, device)"
],
"id": "b19364a05a067c50",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.797\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:11.888135Z",
"start_time": "2025-06-25T06:54:11.759275Z"
}
},
"cell_type": "code",
"source": [
"data = pd.read_csv('./data/notebook2/random_qccnn_metrics.csv')\n",
"epoch = data['epoch']\n",
"train_loss = data['train_loss']\n",
"valid_loss = data['valid_loss']\n",
"train_acc = data['train_acc']\n",
"valid_acc = data['valid_acc']\n",
"\n",
"# 创建图和Axes对象\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"# 绘制训练损失曲线\n",
"ax1.plot(epoch, train_loss, label='Train Loss')\n",
"ax1.plot(epoch, valid_loss, label='Valid Loss')\n",
"ax1.set_title('Training Loss Curve')\n",
"ax1.set_xlabel('Epoch')\n",
"ax1.set_ylabel('Loss')\n",
"ax1.legend()\n",
"\n",
"# 绘制训练准确率曲线\n",
"ax2.plot(epoch, train_acc, label='Train Accuracy')\n",
"ax2.plot(epoch, valid_acc, label='Valid Accuracy')\n",
"ax2.set_title('Training Accuracy Curve')\n",
"ax2.set_xlabel('Epoch')\n",
"ax2.set_ylabel('Accuracy')\n",
"ax2.legend()\n",
"\n",
"plt.show()"
],
"id": "752fc968a05aed09",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+kAAAHUCAYAAABGRmklAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA3+ZJREFUeJzs3Xd4FNXXwPHvphfSCQklBaQESAih9yIdQaoUpUkRRFRA3x9SBKwoiIAgIFIiKlVpShEQqQklhFCkQ0IoCS0kIb3N+8eQ1ZhCyiabcj7Ps09mZ+/MnE1Yds/ee8/VKIqiIIQQQgghhBBCCL0z0HcAQgghhBBCCCGEUEmSLoQQQgghhBBCFBOSpAshhBBCCCGEEMWEJOlCCCGEEEIIIUQxIUm6EEIIIYQQQghRTEiSLoQQQgghhBBCFBOSpAshhBBCCCGEEMWEJOlCCCGEEEIIIUQxIUm6EEIIIYQQQghRTEiSLso0jUaTq9vBgwcLdJ3Zs2ej0WjydezBgwd1EkNBrv3LL78U+bXz49y5c7z++utUrVoVMzMzypUrR4MGDZg7dy4RERH6Dk8IIUQuyHtz7u3YsQONRoODgwOJiYl6jaUkio6O5rPPPqNRo0ZYW1tjamqKu7s7I0eOJDAwUN/hiTLMSN8BCKFP/v7+Ge5/8skn/PXXXxw4cCDD/jp16hToOqNHj6Zr1675OrZBgwb4+/sXOIbS7vvvv2f8+PHUqlWL//u//6NOnTokJycTEBDA8uXL8ff3Z+vWrfoOUwghxHPIe3PurVq1CoCIiAi2bdvGwIED9RpPSXLjxg06d+7MgwcPGDduHB999BHlypUjJCSETZs20bBhQyIjI7GxsdF3qKIM0iiKoug7CCGKixEjRvDLL78QExOTY7u4uDgsLCyKKCr9OXjwIO3bt2fz5s30799f3+Fky9/fn9atW9OpUye2bduGqalphseTkpLYs2cPL7/8coGvFR8fj5mZWb57X4QQQuSNvDdnLTw8HBcXF9q0aYOfnx+tW7dm7969+g4rS8Xtb5OamoqPjw+3bt3i2LFjeHp6Zmqze/du2rZtW+C4FUUhISEBc3PzAp1HlC0y3F2I52jXrh2enp4cPnyYFi1aYGFhwciRIwHYuHEjnTt3pmLFipibm1O7dm0++OADYmNjM5wjqyF17u7u9OjRgz179tCgQQPMzc3x8PBg9erVGdplNaRuxIgRlCtXjuvXr9O9e3fKlSuHi4sL7733Xqbhbnfu3KF///5YWVlha2vLa6+9xqlTp9BoNPj6+urkd3ThwgV69eqFnZ0dZmZm1K9fnx9++CFDm7S0ND799FNq1aqFubk5tra21KtXj0WLFmnbPHz4kDfeeAMXFxdMTU1xdHSkZcuW7N+/P8frf/7552g0GlasWJEpQQcwMTHJkKBrNBpmz56dqZ27uzsjRozQ3vf19UWj0bB3715GjhyJo6MjFhYWbNy4EY1Gw59//pnpHMuWLUOj0XDu3DntvoCAAF5++WXs7e0xMzPDx8eHTZs25fichBBCZE/em+GHH34gJSWFSZMm0bdvX/78809u3bqVqV1kZCTvvfce1apVw9TUlAoVKtC9e3cuX76sbZOYmMjHH39M7dq1MTMzw8HBgfbt2+Pn5wdASEhItrH99z01/fcaGBhI//79sbOz44UXXgDU98NBgwbh7u6Oubk57u7uDB48OMu47969q/1MYGJiQqVKlejfvz/3798nJiYGW1tbxo4dm+m4kJAQDA0NmTdvXra/u23btnH+/HmmTp2aZYIO0K1bN22CPmLECNzd3TO1yerfkEajYcKECSxfvpzatWtjamrKypUrqVChAkOHDs10jsjISMzNzZk8ebJ2X3R0NO+//z5Vq1bFxMSEypUrM3HixEz/hkXpJcPdhciFsLAwhgwZwv/+9z8+//xzDAzU77euXbtG9+7dmThxIpaWlly+fJkvv/ySkydPZhqWl5WzZ8/y3nvv8cEHH+Dk5MTKlSsZNWoU1atXp02bNjkem5yczMsvv8yoUaN47733OHz4MJ988gk2NjbMnDkTgNjYWNq3b09ERARffvkl1atXZ8+ePTodDnflyhVatGhBhQoV+Oabb3BwcOCnn35ixIgR3L9/n//9738AzJ07l9mzZzNjxgzatGlDcnIyly9fJjIyUnuuoUOHEhgYyGeffUbNmjWJjIwkMDCQx48fZ3v91NRUDhw4QMOGDXFxcdHZ8/q3kSNH8tJLL/Hjjz8SGxtLjx49qFChAmvWrKFDhw4Z2vr6+tKgQQPq1asHwF9//UXXrl1p2rQpy5cvx8bGhg0bNjBw4EDi4uIyfCkghBAi98r6e/Pq1aupWLEi3bp1w9zcnHXr1uHr68usWbO0bZ4+fUqrVq0ICQlhypQpNG3alJiYGA4fPkxYWBgeHh6kpKTQrVs3jhw5wsSJE3nxxRdJSUnh+PHjhIaG0qJFizzFla5v374MGjSIcePGaZPLkJAQatWqxaBBg7C3tycsLIxly5bRuHFjLl68SPny5QE1QW/cuDHJyclMmzaNevXq8fjxY/744w+ePHmCk5MTI0eOZMWKFcydOzfDkPSlS5diYmKi/dImK+kjDnr37p2v5/Y827Zt48iRI8ycORNnZ2cqVKhAcHAwy5cv59tvv8Xa2lrbdv369SQkJPD6668D6qiDtm3bcufOHe1z//vvv5k5cybnz59n//79MpqvLFCEEFrDhw9XLC0tM+xr27atAih//vlnjsempaUpycnJyqFDhxRAOXv2rPaxWbNmKf99ubm5uSlmZmbKrVu3tPvi4+MVe3t7ZezYsdp9f/31lwIof/31V4Y4AWXTpk0Zztm9e3elVq1a2vvffvutAii7d+/O0G7s2LEKoKxZsybH55R+7c2bN2fbZtCgQYqpqakSGhqaYX+3bt0UCwsLJTIyUlEURenRo4dSv379HK9Xrlw5ZeLEiTm2+a/w8HAFUAYNGpTrYwBl1qxZmfa7ubkpw4cP195fs2aNAijDhg3L1Hby5MmKubm59vkpiqJcvHhRAZTFixdr93l4eCg+Pj5KcnJyhuN79OihVKxYUUlNTc113EIIURbJe3Nmhw8fVgDlgw8+0D7PqlWrKm5ubkpaWpq23ccff6wAyr59+7I919q1axVA+f7777NtExwcnG1s/31PTf+9zpw587nPIyUlRYmJiVEsLS2VRYsWafePHDlSMTY2Vi5evJjtsTdu3FAMDAyUBQsWaPfFx8crDg4Oyuuvv57jdbt27aoASkJCwnNjVBT1b+vm5pZpf1b/hgDFxsZGiYiIyLD/3LlzCqCsWLEiw/4mTZooDRs21N6fM2eOYmBgoJw6dSpDu19++UUBlF27duUqZlGyyXB3IXLBzs6OF198MdP+mzdv8uqrr+Ls7IyhoSHGxsa0bdsWgEuXLj33vPXr18fV1VV738zMjJo1a2Y57Ou/NBoNPXv2zLCvXr16GY49dOgQVlZWmQrjDB48+Lnnz60DBw7QoUOHTL3YI0aMIC4uTlsAqEmTJpw9e
5bx48fzxx9/EB0dnelcTZo0wdfXl08//ZTjx4+TnJysszgLol+/fpn2jRw5kvj4eDZu3Kjdt2bNGkxNTXn11VcBuH79OpcvX+a1114DICUlRXvr3r07YWFhXLlypWiehBBClDJl+b05vWBcem+xRqNhxIgR3Lp1K8NUrN27d1OzZk06duyY7bl2796NmZlZjj3P+ZHVe2dMTAxTpkyhevXqGBkZYWRkRLly5YiNjc3wt9m9ezft27endu3a2Z6/WrVq9OjRg6VLl6I8K7G1bt06Hj9+zIQJE3T6XPLqxRdfxM7OLsM+Ly8vGjZsyJo1a7T7Ll26xMmTJzP87n///Xc8PT2pX79+hs8NXbp0KRYrCoiiIUm6ELlQsWLFTPtiYmJo3bo1J06c4NNPP+XgwYOcOnWKLVu2AGqBsedxcHDItM/U1DRXx1pYWGBmZpbp2ISEBO39x48f4+TklOnYrPbl1+PHj7P8/VSqVEn7OMDUqVP56quvOH78ON26dcPBwYEOHToQEBCgPWbjxo0MHz6
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:11.966429Z",
"start_time": "2025-06-25T06:54:11.961222Z"
}
},
"cell_type": "code",
"source": [
"class ParameterizedQuantumConvolutionalLayer(nn.Module):\n",
" def __init__(self, nqubit, num_circuits):\n",
" super().__init__()\n",
" self.nqubit = nqubit\n",
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"\n",
" def circuit(self, nqubit):\n",
" cir = dq.QubitCircuit(nqubit)\n",
" cir.rxlayer(encode=True) #对原论文的量子线路结构并无影响,只是做了一个数据编码的操作\n",
" cir.barrier()\n",
" for iter in range(4): #对应原论文中一个量子卷积线路上的深度为4可控参数一共16个\n",
" cir.rylayer()\n",
" cir.cnot_ring()\n",
" cir.barrier()\n",
"\n",
" cir.observable(0)\n",
" return cir\n",
"\n",
" def forward(self, x):\n",
" kernel_size, stride = 2, 2\n",
" # [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]\n",
" x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)\n",
" w = int((x.shape[-1] - kernel_size) / stride + 1)\n",
" x_reshape = x_unflod.reshape(-1, self.nqubit)\n",
"\n",
" exps = []\n",
" for cir in self.cirs: # out_channels\n",
" cir(x_reshape)\n",
" exp = cir.expectation()\n",
" exps.append(exp)\n",
"\n",
" exps = torch.stack(exps, dim=1)\n",
" exps = exps.reshape(x.shape[0], 3, w, w)\n",
" return exps"
],
"id": "7694f4aa38f91ef2",
"outputs": [],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:12.225014Z",
"start_time": "2025-06-25T06:54:12.053468Z"
}
},
"cell_type": "code",
"source": [
"# 此处我们可视化其中一个量子卷积核的线路结构:\n",
"net = ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3)\n",
"net.cirs[0].draw()"
],
"id": "26eb9fed6938a56b",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 2210.55x785.944 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABroAAAJxCAYAAAAdC2LsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAyiJJREFUeJzs3Xd4FFXbBvB7N5veSQIJJRBqQui9IyC9qVQFLIgFfAEFQUVRKRZsqBQBwYIgIr13EEKH0AJJIEAI6SG9b7Ll+4PPKJLA7mZ3Znfm/l3Xe+m7O2fmiZnsvTPnzDkKvV6vBxEREREREREREREREZGNUYpdABEREREREREREREREZEp2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENokdXURERERERERERERERGST2NFFRERERERERERERERENkkldgFElXHu3Dmjtk9PT8fmzZvxzDPPwNfX16A2bdu2NaU0IiISiDFZYEoOAMwCIiJrxhwgIpI33hsiIiI+0UWykp6ejpUrVyI9PV3sUoiISATMASIieWMOEBERs4CISHrY0UVEREREREREREREREQ2iR1dREREREREREREREREZJPY0UVEREREREREREREREQ2iR1dJCvu7u7o168f3N3dxS6FiIhEwBwgIpI35gARETELiIikR6HX6/ViF0FkqnPnzln8GG3btrX4MYiIyHTMAiIieWMOEBHJG3OAiIj4RBfJilqtRnx8PNRqtdilEBGRCJgDRETyxhwgIiJmARGR9LCji2QlNjYWw4YNQ2xsrNilEBGRCJgDRETyxhwgIiJmARGR9KjELoDKp9froSmynZElKmdHKBQKscuQDFv7/QM8B4iIzM3WsoA5QERkXraWAwCzgIjInJgDRESGY0eXldIUqbG23lixyzDYmFtrYO/iJHYZkmFrv3+A5wARkbnZWhYwB4iIzMvWcgBgFhARmRNzgIjIcJy6kIiIiIiIiIiIiIiIiGwSO7qIiIiIiIiIiIiIiIjIJnHqQpKV4OBgnD17VuwyiIhIJMwBIiJ5Yw4QERGzgIhIevhEFxEREREREREREREREdkkdnSRrMTFxWH8+PGIi4sTuxQiIhIBc4CISN6YA0RExCwgIpIeTl1IslJUVISrV6+iqKhI7FKIiAQRl5SH89fSER6ZgZi7OShSa2GnVMDbwxEtGlVB68a+aBniAzcXe7FLFQRzgIjkJitX/f85kI6ImCzkFZRCr9fD1UWFkCAvtAn1RZtQP1TzcRa7VEEwB4hIbtQlWly+nonwyHRciEpHerYapRodHO3tUKeGG1qH+KJ1Yx80rOMJhUIhdrmCYBYQkZzodHrciMtBeOT9a4K4pAKoS7WwVynh6+WI1o190bqxL5o1rAJHBzuxyzUZO7qIiIgkpqhYgz/23sbS9VE4fy29wu1+/f9/OjvZ4dn+9fDGqBC0auwrTJFERGQxer0eh88kY+n6KGz7Kw5arf6xbZ7sUB1vjA7BoG6BUKk48QcRka27cScHP/wZhV+2xSA7r+Sx2zeq44mJI4PxwpAG8PJwFKBCIiKypKxcNX7ZFoMf/oxCTFxuhdut3HwDAODt4YCXnmqIiSNDUD/QQ6gyzYYdXRLi3zEU/TbPeeC10oIi5N5Oxq2NxxC1ajf0Wp1I1ZEQeA4QyZter8dvO25i2ldnkJGtNrhdUbEWP225gZ+23MCTHapjxYddEFTT3YKVkqUwB4goPDIdEz4Ow6XoTKPaHTydhIOnk1CnuhuWf9gZfTrVtFCFZGnMAiJ5u5dZhCmfn8Yfe28b1e76nRy8+cUZzPo+HB9PbIlpzzeBnR0HPtgi5gCRvGk0Onz1awTmLr+IomKtwe2yckvwzeqr+Gb1VYwZWA/fvdMBPl5OFqzUvNjRJUGxW08g/uB5QKGAs58X6o/ojnZzXoRngxo4NWO52OWRAHgOEMlPUloBXp17AruOxVdqPwdPJ6HpsM344q22mDgqRDbTl0gNc4BIfkpKtZi3/BI+W3XZoCe4KnInKR99X9+HV4Y1wtdvt4O7q4MZqyQhMQuI5GfzwTt4fd4J3MsqNnkfhcUazFx4DpsP3cHP87ohOMjLfAWSoJgDRPITeSsLL84+hnNXK57dxxBrd93CwdNJWP5hZwztUdtM1VkWh2ZIUOa1WNzeFIbbG4/h2g/bsWvgLOQn3kPD53rB0cf2Hjs0p4CAAMyZMwcBAQFil2JRPAeI5OV6bDbaj9lR6U6uvxUUafDGp6cwcf5JaCU20o85wBwgkqLCIg2GTjmI+SsuVaqT699+3HQd3V/ajbQMaa1fIpccAJgFRHLz2crLGDbtUKU6uf7t9JV7aD9mO45fSDHL/qyJXLKAOUAkL0fPJ6PD2B2V7uT6W2pGEZ6aehBf/nzFLPuzNHZ0yYCmSI30CzehUCrhUbua2OWIytPTE/3794enp6fYpQiK5wCRdN2Kz8UTL+9GQmqB2fe9fEM0Js4/Cb3ePDdNrQFzgDlAJDXqEi2GTj2AvScSzL7vi9EZePLVPcjKNXw6XGsn1xwAmAVEUvb5qsuY9f15s+83N78U/Sbuw+nLaWbft5jkmgXMASLpOn4hBf0n7UNeQanZ9z1z4Tl89UuE2fdrbuzokgn3OvcDrDgrT+RKxJWVlYUNGzYgKytL7FIEx3OASHqKijUY+MZ+pKRbbrT9j5uu45vVVy22f6ExB5gDRFIz5fNTOHg6yWL7j4jJwqgZhyUz6EHOOQAwC4ikaMuhO3jvO/N3cv2toEiDQZP3IyW90GLHEJqcs4A5QCQ9SWkFGDLlgFHrcRlrxjdnsfPoXYvt3xzY0SVBds6OcKziDkcfD3gFB6L9pxPg07Qu7l2MQV6s9B45N0Zqaiq+/PJLpKamil2KRfEcIJKHDxaH4/qdHKPanFs3BPEHRuPcuiHGHSc228jqrBNzgDlAJCX7TyZgxcbrRrUxJQcOnErCj5uMO461kksOAMwCIjlIzyrG6/NOGNXGlBzIyFZLaqYHuWQBc4BI+vR6PV6de
wJZuSVGtTMlC+4fx3pnelCJXYAQ0tPT8fnnn2PLli1ITEyEn58fRo4cifnz5+OVV17B2rVr8eOPP2LChAlil2oWLaaNQItpIx54LW7PGZx+50eRKiKh8Rwgkr5Tl1Ox8Dfjn7Ty93VBzWquRrUpVmvx0odhOLF6EBQKhdHHJOExB4ikL7+wFBM+Pm50O1NyAACmf3UW/bvURC1/N6PbkjiYBUTSN3XBaaRlGrcml6k5sPVwHNbvvY3R/esZ3ZbEwRwgkr61u26ZtF67KVmQfK8Qby44jV8/6W708YQg+Y6uK1euoE+fPkhNTYWrqysaN26MxMREfPPNN7hz5w7u3r3/yF3z5s1FrtR8bqw9iNhtJ6BU2cGrUSCaTn4Kzr5e0BT/07OrdFBh8L4vcHvLcUR8v7ns9S7fvgEnPy8cHPOJGKWTmRhyDjyxYjr0ej2OvvZN2WsOXm546q+FOD93NW5vDhOjdLPT6/U4dzUdS9dHYVdYPHLySuDirELrEB9MHBmCoT1qw96eD7eS7fn0x8s
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 12
},
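{
"metadata": {},
"cell_type": "code",
"source": [
"# Added sketch (not part of the original run): count the trainable parameters of one\n",
"# parameterized quantum convolution kernel defined above (net.cirs[0]). Four ry-layers\n",
"# on 4 qubits should give 16 trainable rotation angles per circuit, matching the comment\n",
"# in the layer definition; the encoded rx layer contributes no trainable parameters.\n",
"n_params = sum(p.numel() for p in net.cirs[0].parameters() if p.requires_grad)\n",
"print(n_params)"
],
"id": "added_circuit_param_count",
"outputs": [],
"execution_count": null
},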
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T06:54:12.305720Z",
"start_time": "2025-06-25T06:54:12.301733Z"
}
},
"cell_type": "code",
"source": [
"# QCCNN整体网络架构\n",
"class QCCNN(nn.Module):\n",
" def __init__(self):\n",
" super(QCCNN, self).__init__()\n",
" self.conv = nn.Sequential(\n",
" ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3),\n",
" nn.BatchNorm2d(3), # 添加批量归一化\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1)\n",
" )\n",
"\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(8 * 8 * 3, 128),\n",
" nn.BatchNorm1d(128), # 添加批量归一化\n",
" nn.Dropout(0.5), # 增加dropout比例\n",
" nn.ReLU(),\n",
" nn.Linear(128, 10)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = x.reshape(x.size(0), -1)\n",
" x = self.fc(x)\n",
" return x"
],
"id": "ae9e76ee6bca6e2f",
"outputs": [],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:01:17.413980Z",
"start_time": "2025-06-25T06:54:12.401415Z"
}
},
"cell_type": "code",
"source": [
"# 修改QCCNN模型的训练参数\n",
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = QCCNN()\n",
"model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5) # 使用AdamW优化器和适当的权重衰减\n",
"optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(optim_model.state_dict(), './data/notebook2/qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook2/qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "81c62294cae7da16",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 0.531 Valid Acc: 0.784: 16%|█▌ | 48/300 [07:04<37:11, 8.85s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Early stopping at epoch 49 due to no improvement in validation accuracy.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:01:22.540624Z",
"start_time": "2025-06-25T07:01:17.490806Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook2/qccnn_weights.pt', map_location=device)\n",
"qccnn_model = QCCNN()\n",
"qccnn_model.load_state_dict(state_dict)\n",
"qccnn_model.to(device)\n",
"\n",
"test_acc = test_model(qccnn_model, test_loader, device)"
],
"id": "ffbeb8b34fdcdbc0",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.797\n"
]
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:01:22.666238Z",
"start_time": "2025-06-25T07:01:22.660994Z"
}
},
"cell_type": "code",
"source": [
"def vgg_block(in_channel,out_channel,num_convs):\n",
" layers = nn.ModuleList()\n",
" assert num_convs >= 1\n",
" layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1))\n",
" layers.append(nn.ReLU())\n",
" for _ in range(num_convs-1):\n",
" layers.append(nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1))\n",
" layers.append(nn.ReLU())\n",
" layers.append(nn.MaxPool2d(kernel_size=2,stride=2))\n",
" return nn.Sequential(*layers)\n",
"\n",
"VGG = nn.Sequential(\n",
" vgg_block(1, 32, 2), # 增加通道数和调整卷积层数量\n",
" vgg_block(32, 64, 2),\n",
" nn.Flatten(),\n",
" nn.Linear(64 * 4 * 4, 256), # 调整全连接层大小\n",
" nn.BatchNorm1d(256), # 添加批量归一化\n",
" nn.ReLU(),\n",
" nn.Dropout(0.5), # 增加dropout比例\n",
" nn.Linear(256, 128),\n",
" nn.BatchNorm1d(128), # 添加批量归一化\n",
" nn.ReLU(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(128, 10),\n",
" nn.Softmax(dim=-1)\n",
")"
],
"id": "f72e03c426bd658b",
"outputs": [],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:05:52.523565Z",
"start_time": "2025-06-25T07:01:22.782579Z"
}
},
"cell_type": "code",
"source": [
"# 修改VGG模型的训练参数\n",
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"vgg_model = VGG\n",
"vgg_model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.AdamW(vgg_model.parameters(), lr=3e-4, weight_decay=1e-5) # 使用AdamW优化器和适当的权重衰减\n",
"vgg_model, metrics = train_model(vgg_model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(vgg_model.state_dict(), './data/notebook2/vgg_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook2/vgg_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "234337eef155a6de",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 1.563 Valid Acc: 0.855: 28%|██▊ | 85/300 [04:29<11:22, 3.17s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Early stopping at epoch 86 due to no improvement in validation accuracy.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:05:55.080402Z",
"start_time": "2025-06-25T07:05:52.668397Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook2/vgg_weights.pt', map_location=device)\n",
"vgg_model = VGG\n",
"vgg_model.load_state_dict(state_dict)\n",
"vgg_model.to(device)\n",
"\n",
"vgg_test_acc = test_model(vgg_model, test_loader, device)"
],
"id": "ef857e4ec99a951a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.894\n"
]
}
],
"execution_count": 18
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:05:55.325058Z",
"start_time": "2025-06-25T07:05:55.198007Z"
}
},
"cell_type": "code",
"source": [
"vgg_data = pd.read_csv('./data/notebook2/vgg_metrics.csv')\n",
"qccnn_data = pd.read_csv('./data/notebook2/qccnn_metrics.csv')\n",
"vgg_epoch = vgg_data['epoch']\n",
"vgg_train_loss = vgg_data['train_loss']\n",
"vgg_valid_loss = vgg_data['valid_loss']\n",
"vgg_train_acc = vgg_data['train_acc']\n",
"vgg_valid_acc = vgg_data['valid_acc']\n",
"\n",
"qccnn_epoch = qccnn_data['epoch']\n",
"qccnn_train_loss = qccnn_data['train_loss']\n",
"qccnn_valid_loss = qccnn_data['valid_loss']\n",
"qccnn_train_acc = qccnn_data['train_acc']\n",
"qccnn_valid_acc = qccnn_data['valid_acc']\n",
"\n",
"# 创建图和Axes对象\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"# 绘制训练损失曲线\n",
"ax1.plot(vgg_epoch, vgg_train_loss, label='VGG Train Loss')\n",
"ax1.plot(vgg_epoch, vgg_valid_loss, label='VGG Valid Loss')\n",
"ax1.plot(qccnn_epoch, qccnn_train_loss, label='QCCNN Valid Loss')\n",
"ax1.plot(qccnn_epoch, qccnn_valid_loss, label='QCCNN Valid Loss')\n",
"ax1.set_title('Training Loss Curve')\n",
"ax1.set_xlabel('Epoch')\n",
"ax1.set_ylabel('Loss')\n",
"ax1.legend()\n",
"\n",
"# 绘制训练准确率曲线\n",
"ax2.plot(vgg_epoch, vgg_train_acc, label='VGG Train Accuracy')\n",
"ax2.plot(vgg_epoch, vgg_valid_acc, label='VGG Valid Accuracy')\n",
"ax2.plot(qccnn_epoch, qccnn_train_acc, label='QCCNN Train Accuracy')\n",
"ax2.plot(qccnn_epoch, qccnn_valid_acc, label='QCCNN Valid Accuracy')\n",
"ax2.set_title('Training Accuracy Curve')\n",
"ax2.set_xlabel('Epoch')\n",
"ax2.set_ylabel('Accuracy')\n",
"ax2.legend()\n",
"\n",
"plt.show()"
],
"id": "5d20475f38028031",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+kAAAHUCAYAAABGRmklAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XV4FMcbwPHvxT0hBgFiQBKcBIK7E9xdAoEfWrQUKQVatHix4gkUpxQpGijurgGCRJBAIO52+/tjyUGIQoAg83mee7jbnd2dvRy39+7MvKOQJElCEARBEARBEARBEIQ8p5bXFRAEQRAEQRAEQRAEQSaCdEEQBEEQBEEQBEH4QoggXRAEQRAEQRAEQRC+ECJIFwRBEARBEARBEIQvhAjSBUEQBEEQBEEQBOELIYJ0QRAEQRAEQRAEQfhCiCBdEARBEARBEARBEL4QIkgXBEEQBEEQBEEQhC+ECNIFQRAEQRAEQRAE4QshgnThu6ZQKHL0OHbsWK6OM3nyZBQKxQdte+zYsY9Sh9wc+++///7sx/4QN27coHfv3tjb26Ojo4OBgQHly5dn1qxZhIaG5nX1BEEQhBwQ1+ac2717NwqFAjMzMxISEvK0Ll+jyMhIpk2bhqurK0ZGRmhra2NnZ0efPn24cuVKXldP+I5p5HUFBCEvnT17Ns3rKVOmcPToUY4cOZJmecmSJXN1nL59+9KkSZMP2rZ8+fKcPXs213X41q1cuZJBgwbh5OTE6NGjKVmyJElJSVy6dIlly5Zx9uxZduzYkdfVFARBELIhrs05t3r1agBCQ0PZuXMnnTp1ytP6fE0ePnxIo0aNCA4OZsCAAfz6668YGBjg7+/P1q1bqVChAuHh4RgbG+d1VYXvkEKSJCmvKyEIXwp3d3f+/vtvoqOjsywXGxuLnp7eZ6pV3jl27Bh169Zl27ZttG/fPq+rk6mzZ89Ss2ZNGjZsyM6dO9HW1k6zPjExkQMHDtCyZctcHysuLg4dHZ0Pbn0RBEEQ3o+4Nmfs+fPnWFtbU6tWLc6cOUPNmjXx9vbO62pl6Ev726SkpODi4kJAQACnT5+mdOnS6crs37+f2rVr57rekiQRHx+Prq5urvYjfF9Ed3dByEadOnUoXbo0J06coFq1aujp6dGnTx8AtmzZQqNGjbCyskJXV5cSJUowduxYYmJi0uwjoy51dnZ2NG/enAMHDlC+fHl0dXUpXrw4a9asSVMuoy517u7uGBgY8ODBA5o2bYqBgQHW1taMGjUqXXe3J0+e0L59ewwNDTExMaFbt25cvHgRhUKBl5fXR3mPbt26RatWrciXLx86Ojo4Ozuzdu3aNGWUSiVTp07FyckJXV1dTExMKFu2LH/88YeqzMuXL/nf//6HtbU12traWFhYUL16dQ4fPpzl8adPn45CoWDFihXpAnQALS2tNAG6QqFg8uTJ6crZ2dnh7u6ueu3l5YVCocDb25s+ffpgYWGBnp4eW7ZsQaFQ8N9//6Xbx59//olCoeDGjRuqZZcuXaJly5aYmpqio6ODi4sLW7duzfKcBEEQhMyJazOsXbuW5ORkRowYQdu2bfnvv/8ICAhIVy48PJxRo0ZRpEgRtLW1sbS0pGnTpty9e1dVJiEhgd9++40SJUqgo6ODmZkZdevW5cyZMwD4+/tnWrd3r6mp7+uVK1do3749+fLlo2jRooB8PezcuTN2dnbo6upiZ2dHly5dMqz306dPVb8JtLS0KFiwIO3bt+fFixdER0djYmJC//79023n7++Puro6s2fPzvS927lzJzdv3mTcuHEZBugAbm5uqgDd3d0dOzu7dGUy+gwpFAqGDBnCsmXLKFGiBNra2qxatQpLS0t69OiRbh/h4eHo6uoycuRI1bLIyEh+/PFH7O3t0dLSolChQgwfPjzdZ1j4donu7oKQA0FBQXTv3p2ffvqJ6dOno6Ym39+6f/8+TZs2Zfjw4ejr63P37l1+//13Lly4kK5bXkauX7/OqFGjGDt2LPnz52fVqlV4eHhQrFgxatWqleW2SUlJtGzZEg8PD0aNGsWJEyeYMmUKxsbGTJw4EYCYmBjq1q1LaGgov//+O8WKFePAgQMftTvcvXv3qFatGpaWlixcuBAzMzPWr1+Pu7s7L1684KeffgJg1qxZTJ48mQkTJlCrVi2SkpK4e/cu4eHhqn316NGDK1euMG3aNBwdHQkPD+fKlSuEhIRkevyUlBSOHDlChQoVsLa2/mjn9bY+ffrQrFkz/vrrL2JiYmjevDmWlpZ4enpSv379NGW9vLwoX748ZcuWBeDo0aM0adKEypUrs2zZMoyNjdm8eTOdOnUiNjY2zU0BQRAEIee+92vzmjVrsLKyws3NDV1dXTZu3IiXlxeTJk1SlYmKiqJGjRr4+/szZswYKleuTHR0NCdOnCAoKIjixYuTnJyMm5sbJ0+eZPjw4dSrV4/k5GTOnTtHYGAg1apVe696pWrbti2dO3dmwIABquDS398fJycnOnfujKmpKUFBQfz5559UrFgRHx8fzM3NATlAr1ixIklJSYwfP56yZcsSEhLCwYMHCQsLI3/+/PTp04cVK1Ywa9asNF3Sly5dipaWluqmTUZSexy0bt36g84tOzt37uTkyZNMnDiRAgUKYGlpiZ+fH8uWLWPJkiUYGRmpym7atIn4+Hh69+4NyL0OateuzZMnT1Tnfvv2bSZOnMjNmzc5fPiw6M33PZAEQVDp1auXpK+vn2ZZ7dq1JUD677//stxWqVRKSUlJ0vHjxyVAun79umrdpEmTpHf/u9na2ko6OjpSQECAallcXJxkamoq9e/fX7Xs6NGjEiAdPXo0TT0BaevWrWn22bRpU8nJyUn1esmSJRIg7d+/P025/v37S4Dk6emZ5TmlHnvbtm2ZluncubOkra0tBQYGplnu5uYm6enpSeHh4ZIkSVLz5s0lZ2fnLI9nYGAgDR8+PMsy73r+/LkESJ07d87xNoA0adKkdMttbW2lXr16qV57enpKgNSzZ890ZUeOHCnp6uqqzk+SJMnHx0cCpEWLFqmWFS9eXHJxcZGSkpLSbN+8eXPJyspKSklJyXG9BUEQvkfi2pzeiRMnJEAaO3as6jzt7e0lW1tbSalUqsr99ttvEiAdOnQo032tW7dOAqSVK1dmWsbPzy/Tur17TU19XydOnJjteSQnJ0vR0dGSvr6+9Mcff6iW9+nTR9LU1JR8fHwy3fbhw4eSmpqaNH/+fNWyuLg4yczMTOrdu3eWx23SpIkESPHx8dnWUZLkv62trW265Rl9hgDJ2NhYCg0NTbP8xo0bEiCtWLEizfJKlSpJFSpUUL2eMWOGpKamJl28eDFNub///lsCpH379uWozsLXTXR3F4QcyJcvH/Xq1Uu3/NGjR3Tt2pUCBQqgrq6OpqYmtWvXBuDOnTvZ7tfZ2RkbGxvVax0dHRwdHTPs9vUuhUJBixYt0iwrW7Zsmm2PHz+OoaFhusQ4Xbp0yXb/OXXkyBHq16+frhXb3d2d2NhYVQKgSpUqcf36dQYNGsTBgweJj
IxMt69KlSrh5eXF1KlTOXfuHElJSR+tnrnRrl27dMv69OlDXFwcW7ZsUS3z9PREW1ubrl27AvDgwQPu3r1Lt27dAEhOTlY9mjZtSlBQEPfu3fs8JyEIgvCN+Z6vzakJ41JbixUKBe7u7gQEBKQZirV//34cHR1p0KBBpvvav38/Ojo6WbY8f4iMrp3R0dGMGTOGYsWKoaGhgYaGBgYGBsTExKT52+zfv5+6detSokSJTPdfpEgRmjdvztKlS5Fep9jauHEjISEhDBky5KOey/uqV68e+fLlS7OsTJkyVKhQAU9PT9WyO3fucOHChTTv/Z49eyhdujTOzs5pfjc0btz4i5hRQPg8RJAuCDlgZWWVbll0dDQ1a9bk/PnzTJ06lWPHjnHx4kX++ecfQE4wlh0zM7N0y7S1tXO0rZ6eHjo6Oum2jY+PV70OCQkhf/786bbNaNmHCgkJyfD9KViwoGo9wLhx45gzZw7nzp3Dzc0NMzMz6tevz6VLl1TbbNmyhV69erFq1SqqVq2KqakpPXv
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-25T07:05:55.477381Z",
"start_time": "2025-06-25T07:05:55.462043Z"
}
},
"cell_type": "code",
"source": [
"# 这里我们对比不同模型之间可训练参数量的区别\n",
"\n",
"def count_parameters(model):\n",
" \"\"\"\n",
" 计算模型的参数数量\n",
" \"\"\"\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"number_params_VGG = count_parameters(VGG)\n",
"number_params_QCCNN = count_parameters(QCCNN())\n",
"print(f'VGG 模型可训练参数量:{number_params_VGG}\\t QCCNN模型可训练参数量{number_params_QCCNN}')"
],
"id": "72451dcf013280ac",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"VGG 模型可训练参数量362346\t QCCNN模型可训练参数量26304\n"
]
}
],
"execution_count": 20
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}