DeepQuantom-CNN/Origin.ipynb

856 lines
310 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-06-24T16:07:02.423934Z",
"start_time": "2025-06-24T16:07:02.418513Z"
}
},
"source": [
"# 首先我们导入所有需要的包:\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import deepquantum as dq\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision.transforms as transforms\n",
"from tqdm import tqdm\n",
"from sklearn.metrics import roc_auc_score\n",
"from torch.utils.data import DataLoader\n",
"# from torchvision.datasets import MNIST, FashionMNIST\n",
"\n",
"def seed_torch(seed=1024):\n",
" \"\"\"\n",
" Set random seeds for reproducibility.\n",
"\n",
" Args:\n",
" seed (int): Random seed number to use. Default is 1024.\n",
" \"\"\"\n",
"\n",
" random.seed(seed)\n",
" os.environ['PYTHONHASHSEED'] = str(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed(seed)\n",
"\n",
" # Seed all GPUs with the same seed if using multi-GPU\n",
" torch.cuda.manual_seed_all(seed)\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True\n",
"\n",
"seed_torch(1024)"
],
"outputs": [],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:02.531770Z",
"start_time": "2025-06-24T16:07:02.525115Z"
}
},
"cell_type": "code",
"source": [
"def calculate_score(y_true, y_preds):\n",
" # 将模型预测结果转为概率分布\n",
" preds_prob = torch.softmax(y_preds, dim=1)\n",
" # 获得预测的类别(概率最高的一类)\n",
" preds_class = torch.argmax(preds_prob, dim=1)\n",
" # 计算准确率\n",
" correct = (preds_class == y_true).float()\n",
" accuracy = correct.sum() / len(correct)\n",
" return accuracy.cpu().numpy()\n",
"\n",
"\n",
"def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):\n",
" \"\"\"\n",
" 训练和验证模型。\n",
"\n",
" Args:\n",
" model (torch.nn.Module): 要训练的模型。\n",
" criterion (torch.nn.Module): 损失函数。\n",
" optimizer (torch.optim.Optimizer): 优化器。\n",
" train_loader (torch.utils.data.DataLoader): 训练数据加载器。\n",
" valid_loader (torch.utils.data.DataLoader): 验证数据加载器。\n",
" num_epochs (int): 训练的epoch数。\n",
"\n",
" Returns:\n",
" model (torch.nn.Module): 训练后的模型。\n",
" \"\"\"\n",
"\n",
" model.train()\n",
" train_loss_list = []\n",
" valid_loss_list = []\n",
" train_acc_list = []\n",
" valid_acc_list = []\n",
"\n",
" with tqdm(total=num_epochs) as pbar:\n",
" for epoch in range(num_epochs):\n",
" # 训练阶段\n",
" train_loss = 0.0\n",
" train_acc = 0.0\n",
" for images, labels in train_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" optimizer.zero_grad()\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_acc += calculate_score(labels, outputs)\n",
"\n",
" train_loss /= len(train_loader)\n",
" train_acc /= len(train_loader)\n",
"\n",
" # 验证阶段\n",
" model.eval()\n",
" valid_loss = 0.0\n",
" valid_acc = 0.0\n",
" with torch.no_grad():\n",
" for images, labels in valid_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" valid_loss += loss.item()\n",
" valid_acc += calculate_score(labels, outputs)\n",
"\n",
" valid_loss /= len(valid_loader)\n",
" valid_acc /= len(valid_loader)\n",
"\n",
" pbar.set_description(f\"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}\")\n",
" pbar.update()\n",
"\n",
"\n",
" train_loss_list.append(train_loss)\n",
" valid_loss_list.append(valid_loss)\n",
" train_acc_list.append(train_acc)\n",
" valid_acc_list.append(valid_acc)\n",
"\n",
" metrics = {'epoch': list(range(1, num_epochs + 1)),\n",
" 'train_acc': train_acc_list,\n",
" 'valid_acc': valid_acc_list,\n",
" 'train_loss': train_loss_list,\n",
" 'valid_loss': valid_loss_list}\n",
"\n",
"\n",
"\n",
" return model, metrics\n",
"\n",
"def test_model(model, test_loader, device):\n",
" model.eval()\n",
" test_acc = 0.0\n",
" with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" test_acc += calculate_score(labels, outputs)\n",
"\n",
" test_acc /= len(test_loader)\n",
" print(f'Test Acc: {test_acc:.3f}')\n",
" return test_acc"
],
"id": "cc4c2323375a0d64",
"outputs": [],
"execution_count": 24
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:02.643737Z",
"start_time": "2025-06-24T16:07:02.631472Z"
}
},
"cell_type": "code",
"source": [
"# 定义图像变换\n",
"trans1 = transforms.Compose([\n",
" transforms.Resize((18, 18)), # 调整大小为18x18\n",
" transforms.ToTensor() # 转换为张量\n",
"])\n",
"\n",
"trans2 = transforms.Compose([\n",
" transforms.Resize((16, 16)), # 调整大小为16x16\n",
" transforms.ToTensor() # 转换为张量\n",
"])\n",
"train_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)\n",
"test_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)\n",
"\n",
"# 定义训练集和测试集的比例\n",
"train_ratio = 0.8 # 训练集比例为80%验证集比例为20%\n",
"valid_ratio = 0.2\n",
"total_samples = len(train_dataset)\n",
"train_size = int(train_ratio * total_samples)\n",
"valid_size = int(valid_ratio * total_samples)\n",
"\n",
"# 分割训练集和测试集\n",
"train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])\n",
"\n",
"# 加载随机抽取的训练数据集\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, drop_last=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)"
],
"id": "4b641527c641afc1",
"outputs": [],
"execution_count": 25
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:02.734411Z",
"start_time": "2025-06-24T16:07:02.731493Z"
}
},
"cell_type": "code",
"source": [
"singlegate_list = ['rx', 'ry', 'rz', 's', 't', 'p', 'u3']\n",
"doublegate_list = ['rxx', 'ryy', 'rzz', 'swap', 'cnot', 'cp', 'ch', 'cu', 'ct', 'cz']"
],
"id": "1c3e55f43e47a4f1",
"outputs": [],
"execution_count": 26
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:02.827116Z",
"start_time": "2025-06-24T16:07:02.822105Z"
}
},
"cell_type": "code",
"source": [
"# 随机量子卷积层\n",
"class RandomQuantumConvolutionalLayer(nn.Module):\n",
" def __init__(self, nqubit, num_circuits, seed:int=1024):\n",
" super(RandomQuantumConvolutionalLayer, self).__init__()\n",
" random.seed(seed)\n",
" self.nqubit = nqubit\n",
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"\n",
" def circuit(self, nqubit):\n",
" cir = dq.QubitCircuit(nqubit)\n",
" cir.rxlayer(encode=True) # 对原论文的量子线路结构并无影响,只是做了一个数据编码的操作\n",
" cir.barrier()\n",
" for iter in range(3):\n",
" for i in range(nqubit):\n",
" singlegate = random.choice(singlegate_list)\n",
" getattr(cir, singlegate)(i)\n",
" control_bit, target_bit = random.sample(range(0, nqubit - 1), 2)\n",
" doublegate = random.choice(doublegate_list)\n",
" if doublegate[0] in ['r', 's']:\n",
" getattr(cir, doublegate)([control_bit, target_bit])\n",
" else:\n",
" getattr(cir, doublegate)(control_bit, target_bit)\n",
" cir.barrier()\n",
"\n",
" cir.observable(0)\n",
" return cir\n",
"\n",
" def forward(self, x):\n",
" kernel_size, stride = 2, 2\n",
" # [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]\n",
" x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)\n",
" w = int((x.shape[-1] - kernel_size) / stride + 1)\n",
" x_reshape = x_unflod.reshape(-1, self.nqubit)\n",
"\n",
" exps = []\n",
" for cir in self.cirs: # out_channels\n",
" cir(x_reshape)\n",
" exp = cir.expectation()\n",
" exps.append(exp)\n",
"\n",
" exps = torch.stack(exps, dim=1)\n",
" exps = exps.reshape(x.shape[0], 3, w, w)\n",
" return exps"
],
"id": "f03fcd820876a62",
"outputs": [],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:03.018614Z",
"start_time": "2025-06-24T16:07:02.915469Z"
}
},
"cell_type": "code",
"source": [
"net = RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024)\n",
"net.cirs[0].draw()"
],
"id": "fcea5aa513a0bd68",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1207.22x367.889 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7UAAAEvCAYAAACaO+Y5AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAZ0RJREFUeJzt3XlcVPX+x/HXDINsoiii4IK7oqTZdV/KpTL1mlppddOWX9m+a3rL6pZltqett8Vum2aWmVpZlmmJ+1aJOyoiIqAIsij7zO8PkiJQAWEOc877+Xj0SM45M/Me5nw/h8+czeZyuVyIiIiIiIiIeCC70QFEREREREREKktNrYiIiIiIiHgsNbUiIiIiIiLisdTUioiIiIiIiMdSUysiIiIiIiIeS02tiIiIiIiIeCw1tSIiIiIiIuKx1NSKiIiIiIiIx1JTKyIiIiIiIh5LTa2IiIiIiIh4LDW1IiIiIiIi4rHU1IqIiIiIiIjHUlMrIiIiIiIiHktNrYiIiIiIiHgsNbUiIiIiIiLisdTUioiIiIiIiMdSUysiIiIiIiIeS02tiIiIiIiIeCw1tSIiIiIiIuKx1NSKiIiIiIiIx1JTKyIiIiIiIh5LTa2IiIiIiIh4LDW1IiIiIiIi4rHU1IqIiIiIiIjHUlMrIiIiIiIiHktNrYiIiIiIiHgsNbUiIiIiIiLisdTUioiIiIiIiMdyGB1A5Fxs3LixQsunpKSwYMECrrzySho0aFCux3Tv3r0y0UTETSpSBypTA0B1QGo2bQtFxOrbQu2pFUtJSUlh1qxZpKSkGB1FRAygGiCicSBidWasAWpqRURERERExGOpqRURERERERGPpaZWREREREREPJaaWrGUwMBAhgwZQmBgoNFRRMQAqgEiGgciVmfGGmBzuVwuo0OIVFZFr/hYGTX5Sm8iojogojEgIlavA9pTK5aSm5tLfHw8ubm5RkcREQOoBohoHIhYnRlrgJpasZTY2FiuuuoqYmNjjY4iIgZQDRDROBCxOjPWAIfRAaRsLpeLgmzP+fbE4eeDzWYzOoZH8rTPWipGY0MqyxNrg9Z3qQqeuO5LSZWpBWb93FUX3UNNbQ1VkJ3LnNbjjI5RbmP3zcbb39foGB7J0z5rqRiNDaksT6wNWt+lKnjiui8lVaYWmPVzV110Dx1+LCIiIiIiIh5LTa2IiIiIiIh4LB1+LJYSERHBhg0bjI4hIgZRDRDROBCxOjPWAO2pFREREREREY+lplYsJS4ujptvvpm4uDijo4iIAVQDRDQORKzOjDVATa1YSnZ2Ntu2bSM7O9voKCJiANUAEY0DEaszYw1QUysiIiIiIiIeSxeKMpHQ3pEMWTC1xLT8E9lk7E9k3/yV7Hx/Ca5Cp0HpxIraXD2Afq/ew6r732Dv5z+Xml+7aQijN/6XvfNWsOqBN90fUMRktB0QqXm0LRSpfmpqTSh24Wril20Cmw2/kCDajOlPj6k3UbdtE9ZOesfoeCIiUs20HRAREStRU2tCqdtj2f9lVPHPuz9cyqiombS77mK2PDeX3GMZBqYzVlhYGFOnTiUsLMzoKCJiAKvUAG0H5EysMg5EpGxmrAE6p9YCCrJzSdmyF5vdTp3mjYyOY6i6desydOhQ6tata3QUETGAVWuAtgPyV1YdByJSxIw1QE2tRQS2KPojJict0+AkxkpLS+OLL74gLS3N6CgiYgAr1wBtB+QUK48DETFnDVBTa0Jefj741A/EJ7gOQRHh9Jw+nuBOrTj6awyZsUlGxzNUcnIyL774IsnJyUZHEREDWKUGaDsgZ2KVcSAiZTNjDbDEObUpKSk899xzfPXVVyQkJBASEsLVV1/NtGnTuPXWW5kzZw7vvfce48ePNzpqlegyYQxdJowpMS3uu/Ws+/d7BiUSERF30nZARESsxPRN7datWxk8eDDJyckEBATQsWNHEhISeOWVVzhw4AAHDx4E4Pzzzzc4adXZM2cZsYtWY3d4EdQ+nE73jsKvQRAFOXnFy9hrObh86Qvs/2oV0a8tKJ7eb+bd+IYEsWzsM0ZEF4tyuVxGRxAxFW0HRDyPtoXl03JUXyLvHEFQ26YUZOdy+JetbJ4+mxOHUoyOJgYy9eHHKSkpDB8+nOTkZCZPnkxSUhJbtmwhOTmZp59+mgULFvDbb79ht9s577zzjI5bZTIPJJEYFU3Cit/Y/vZifrr+WRr8ow29n7u1eBlnXgFR971O5/uuoF7H5gCED+lO00u7sXrCW0ZFF5M59Qe0l2+tMud7+fsAUPiXP7RF5NxpOyBSc2hbWHUi/m8I/f/7IIU5eWx88iN2vPctjft3ZtjiZ/BrVM/oeGIgUze19913H/Hx8UyYMIHnn3+e2rVrF8979NFHiYiIoKCggHbt2uHn52dg0up1dEsM+7+MotWVFxLyj7bF01OjY9n25iIufO1e/MPq0/uF21k/ZRbZyeY5afzv/P396dmzJ/7+/kZHsYSsg0cACGrXtMz5QW2Lpmf+sZxIdbNqDdB2QP7KquPAKNoWVg2ferXpOmUsKVv38f2VT7D74x/YOvNLfrzuGfwb1eOCSdcYHdFjmLEGmLap3bFjB/PmzaNhw4Y8/fTTpebbbDa6du0KlD70ODY2lhEjRhAYGEi9evW44YYbOHbsmFtyV5ffZ8zHWVDIBZOvLTF962sLcOYXMOLHF0las53YRasNSuge4eHhvP7664SHhxsdxRKORe8nK+EoLUf1LfUNqt3bQYebh+JyOon/YZNBCcVqrFwDtB2QU6w8DoygbWHVaHZZD7xr+7Fz1hJchc7i6cd+30fyup20GNEHu7fpz6ysEmasAaZtaufMmYPT6WTcuHGn/RbCx6focI+/NrWZmZkMHDiQQ4cOMXfuXN59912ioqIYPnw4TqezzOfxBJkHkohdtJrG/c+nYc8OxdNdhU6SN+zCN7gue+etMDChexQWFpKVlUVhYaHRUSzBVehk3b/fwzvQn5HLX6Hro+NoN+4SOj84mst/eIHQPpFEv7GQjH2HjY4qFmHlGqDtgJxi5XFgBG0Lq0bIBW0AOLJpd6l5RzbtplagP3XbNHF3LI9kxhpg2qb2p59+AmDgwIGnXebQoUNAyab23XffJSEhgYULFzJ8+HDGjBnDnDlzWLduHYsXL67e0NVs66tf4iwsLHF4RsPu7Wn7r0HsfH8JPZ76v9Oe72EWMTExDBo0iJiYGKOjWMahn7awZMRjJK3eRpur+9Nr+njOu/1yco5l8PNtL7Pl2U+NjigWYvUaoO2AgMaBEbQtPHf+ofUBOJmYWmreicNFR1T6h9V3ayZPZcYaYNp99HFxcQC0aNGizPm5ubmsX78eKNnUfvPNN/Tr16/E7vg+ffrQokULvv76a0aNGlXhLN26dSMpqWL3BfR22XmCHhV6TNLa7XwYNvq089NjEvi46Z9/yDj8fOj36j1smf4puz5aytCvnqLro2PZ8PgHFXpdgHZt25Fvc/+e7NGjT/9+y3LkSNH5Kt999x2bN28u12OuuOKKCueqiMp81p7m2O/7+Pm2l42OYQijxoaVVKQOVKYGQPXXgbJ42nYAtL4bxQzbwr8y63bRStvCytSCs33uXn5/XFArN7/UvFPTHH8sU5O4qy6aYVsYGhrKpk2VOwzftE3tiRMnAMjJySlz/uzZs0lPTyc4OJgmTf48VGHHjh2MGTOm1PIdO3Zkx44dlcqSlJREQkJChR5Ty+YFjSr1cuXWfeqNZMUfZdeH3wOw6v43GLHsJQ4u2UDS2u0Veq7DiYfJc7n/EIZTn3N5ZWdnF/+/vI+t6GdXUe74rMU4Ro0NK6lIHahMDYDqrwNl8bTtAGh9N4oZtoV/pe2i56tMLTjb516YnQuAl493qStFO/44wqTgj2VqEnfVRbNuC8vLtE1taGgomZmZbNq0iW7dupWYFx8fzyOPPAJA586dS8xLS0sjKCio1PPVq1ePvXv3VjpLRXm77FCNX+o0GdiFliP6sujiicXTMuOS2fzMHPrOvItFgyZScKLsLwTK0jissSHfzgcEBFRo+VMD18/Pr9yP/euXHtWhuj9rMZZRY8N
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:07:03.130980Z",
"start_time": "2025-06-24T16:07:03.126867Z"
}
},
"cell_type": "code",
"source": [
"# 基于随机量子卷积层的混合模型\n",
"class RandomQCCNN(nn.Module):\n",
" def __init__(self):\n",
" super(RandomQCCNN, self).__init__()\n",
" self.conv = nn.Sequential(\n",
" RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024), # num_circuits=3代表我们在quanv1层只用了3个量子卷积核\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1),\n",
" nn.Conv2d(3, 6, kernel_size=2, stride=1),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1)\n",
" )\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(6 * 6 * 6, 1024),\n",
" nn.Dropout(0.4),\n",
" nn.Linear(1024, 10)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = x.reshape(x.size(0), -1)\n",
" x = self.fc(x)\n",
" return x"
],
"id": "64082ff8ea82fe8",
"outputs": [],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:33.786208Z",
"start_time": "2025-06-24T16:07:03.216673Z"
}
},
"cell_type": "code",
"source": [
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"seed_torch(1024) # 重新设置随机种子\n",
"model = RandomQCCNN()\n",
"model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001) # 添加正则化项\n",
"optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(optim_model.state_dict(), './data/notebook1/random_qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook1/random_qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "19b3021c114a9129",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 0.620 Valid Acc: 0.756: 100%|██████████| 300/300 [33:30<00:00, 6.70s/it]\n"
]
}
],
"execution_count": 30
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:38.222252Z",
"start_time": "2025-06-24T16:40:33.895998Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook1/random_qccnn_weights.pt', map_location=device)\n",
"random_qccnn_model = RandomQCCNN()\n",
"random_qccnn_model.load_state_dict(state_dict)\n",
"random_qccnn_model.to(device)\n",
"\n",
"test_acc = test_model(random_qccnn_model, test_loader, device)"
],
"id": "49ceb326295cd4a9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.769\n"
]
}
],
"execution_count": 31
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:38.644386Z",
"start_time": "2025-06-24T16:40:38.356492Z"
}
},
"cell_type": "code",
"source": [
"data = pd.read_csv('./data/notebook1/random_qccnn_metrics.csv')\n",
"epoch = data['epoch']\n",
"train_loss = data['train_loss']\n",
"valid_loss = data['valid_loss']\n",
"train_acc = data['train_acc']\n",
"valid_acc = data['valid_acc']\n",
"\n",
"# 创建图和Axes对象\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"# 绘制训练损失曲线\n",
"ax1.plot(epoch, train_loss, label='Train Loss')\n",
"ax1.plot(epoch, valid_loss, label='Valid Loss')\n",
"ax1.set_title('Training Loss Curve')\n",
"ax1.set_xlabel('Epoch')\n",
"ax1.set_ylabel('Loss')\n",
"ax1.legend()\n",
"\n",
"# 绘制训练准确率曲线\n",
"ax2.plot(epoch, train_acc, label='Train Accuracy')\n",
"ax2.plot(epoch, valid_acc, label='Valid Accuracy')\n",
"ax2.set_title('Training Accuracy Curve')\n",
"ax2.set_xlabel('Epoch')\n",
"ax2.set_ylabel('Accuracy')\n",
"ax2.legend()\n",
"\n",
"plt.show()"
],
"id": "45287356d5a9a0ad",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/IAAAHUCAYAAACZCBM6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA3OFJREFUeJzs3Xd4U+X7x/F3kjbdE2gpq+y995AlAoKiiPxA/SrTgYiKOHGAol9xi4qiKFD5qoiC4AIUlaUgyFbZs4yWVbpnmvP747QppQUKtA3j87quXE2ePOfkPqmS3rmfYTEMw0BERERERERELgtWdwcgIiIiIiIiIkWnRF5ERERERETkMqJEXkREREREROQyokReRERERERE5DKiRF5ERERERETkMqJEXkREREREROQyokReRERERERE5DKiRF5ERERERETkMqJEXkREREREROQyokRe5BwsFkuRbkuXLr2o13n++eexWCwXdOzSpUuLJYaLee05c+aU+mtfiM2bNzN06FCqVauGt7c3/v7+NG/enNdee424uDh3hyciIkWgz+ai++6777BYLJQpU4aMjAy3xnI5SkxM5L///S8tW7YkMDAQLy8vqlatyrBhw1i/fr27w5OrmIe7AxC51K1atSrf4xdffJElS5bw22+/5WuvX7/+Rb3O3XffzfXXX39BxzZv3pxVq1ZddAxXuo8//piRI0dSp04dHn/8cerXr09WVhZr167lww8/ZNWqVcybN8/dYYqIyDnos7nopk2bBkBcXBzz589n4MCBbo3ncrJ792569OjB0aNHGTFiBC+88AL+/v7s27ePr776ihYtWhAfH09QUJC7Q5WrkMUwDMPdQYhcToYMGcKcOXNITk4+a7/U1FR8fX1LKSr3Wbp0KV27duXrr7+mf//+7g7njFatWkXHjh3p3r078+fPx8vLK9/zmZmZLFq0iJtuuumiXystLQ1vb+8LruKIiMj50Wdz4WJjY6lcuTKdOnVi5cqVdOzYkZ9//tndYRXqUvvdZGdn06xZM/bv388ff/xBw4YNC/RZuHAhnTt3vui4DcMgPT0dHx+fizqPXF00tF6kGHTp0oWGDRuyfPly2rdvj6+vL8OGDQNg9uzZ9OjRg4iICHx8fKhXrx5PPfUUKSkp+c5R2PC9qlWrcuONN7Jo0SKaN2+Oj48PdevWZfr06fn6FTZ8b8iQIfj7+7Nr1y569+6Nv78/lStX5tFHHy0wtO7gwYP079+fgIAAgoOD+c9//sNff/2FxWIhKiqqWN6jf/75h5tvvpmQkBC8vb1p2rQpn376ab4+TqeTl156iTp16uDj40NwcDCNGzfmnXfecfU5duwY9957L5UrV8bLy4ty5crRoUMHfvnll7O+/ssvv4zFYmHq1KkFkngAu92eL4m3WCw8//zzBfpVrVqVIUOGuB5HRUVhsVj4+eefGTZsGOXKlcPX15fZs2djsVj49ddfC5xjypQpWCwWNm/e7Gpbu3YtN910E6GhoXh7e9OsWTO++uqrs16TiIicmT6b4dNPP8XhcPDII4/Qr18/fv31V/bv31+gX3x8PI8++ijVq1fHy8uLsLAwevfuzbZt21x9MjIymDBhAvXq1cPb25syZcrQtWtXVq5cCcC+ffvOGNvpn6m57+v69evp378/ISEh1KhRAzA/D2+77TaqVq2Kj48PVatW5fbbby807kOHDrn+JrDb7VSoUIH+/ftz5MgRkpOTCQ4O5r777itw3L59+7DZbLz++utnfO/mz5/P33//zdixYwtN4gF69erlSuKHDBlC1apVC/Qp7L8hi8XCqFGj+PDDD6lXrx5eXl588sknhIWFcddddxU4R3x8PD4+PowZM8bVlpiYyGOPPUa1atWw2+1UrFiR0aNHF/hvWK5cGlovUkxiYmK48847eeKJJ3j55ZexWs3vyXbu3Env3r0ZPXo0fn5+bNu2jVdffZU1a9YUGAJYmE2bNvHoo4/y1FNPER4ezieffMLw4cOpWbMmnTp1OuuxWVlZ3HTTTQwfPpxHH32U5cuX8+KLLxIUFMS4ceMASElJoWvXrsTFxfHqq69Ss2ZNFi1aVKxD77Zv30779u0JCwvj3XffpUyZMnz22WcMGTKEI0eO8MQTTwDw2muv8fzzz/Pss8/SqVMnsrKy2LZtG/Hx8a5z3XXXXaxfv57//ve/1K5dm/j4eNavX8+JEyfO+PrZ2dn89ttvtGjRgsqVKxfbdZ1q2LBh3HDDDfzvf/8jJSWFG2+8kbCwMGbMmEG3bt3y9Y2KiqJ58+Y0btwYgCVLlnD99dfTpk0bPvzwQ4KCgvjyyy8ZOHAgqamp+b44EBGRorvaP5unT59OREQEvXr1wsfHhy+++IKoqCjGjx/v6pOUlMQ111zDvn37ePLJJ2nTpg3JycksX76cmJgY6tati8PhoFevXqxYsYLRo0dz7bXX4nA4+PPPP4mOjqZ9+/bnFVeufv36cdtttzFixAhXArpv3z7q1KnDbbfdRmhoKDExMUyZMoVWrVqxZcsWypYtC5hJfKtWrcjKyuLpp5+mcePGnDhxgp9++omTJ08SHh7OsGHDmDp1Kq+99lq+4e8ffPABdrvd9cVOYXJHLvTt2/eCru1c5s+fz4oVKxg3bhzly5cnLCyMvXv38uGHH/L+++8TGBjo6jtr1izS09MZOnQoYI5e6Ny5MwcPHnRd+7///su4ceP4+++/+eWXXzQq8GpgiMh5GTx4sOHn55evrXPnzgZg/Prrr2c91ul0GllZWcayZcsMwNi0aZPrufHjxxun/y8ZGRlpeHt7G/v373e1paWlGaGhocZ9993naluyZIkBGEuWLMkXJ2B89dVX+c7Zu3dvo06dOq7H77//vgEYCxcuzNfvvvvuMwBjxowZZ72m3Nf++uuvz9jntttuM7y8vIzo6Oh87b169TJ8fX2N+Ph4wzAM48YbbzSaNm161tfz9/c3Ro8efdY+p4uNjTUA47bbbivyMYAxfvz4Au2RkZHG4MGDXY9nzJhhAMagQYMK9B0zZozh4+Pjuj7DMIwtW7YYgPHee++52urWrWs0a9bMyMrKynf8jTfeaERERBjZ2dlFjltE5Gqkz+aCli9fbgDGU0895brOatWqGZGRkYbT6XT1mzBhggEYixcvPuO5Zs6caQDGxx9/fMY+e/fuPWNsp3+m5r6v48aNO+d1OBwOIzk52fDz8zPeeecdV/uwYcMMT09PY8uWLWc8dvfu3YbVajXefvttV1taWppRpkwZY+jQoWd93euvv94AjPT09HPGaBjm7zYyMrJAe2H/DQFGUFCQERcXl6998+bNBmBMnTo1X3vr1q2NFi1auB5PnDjRsFqtxl9//ZWv35w5cwzAWLBgQZFilsubhtaLFJOQkBCuvfbaAu179uzhjjvuoHz58thsNjw9PencuTMAW7duPed5mzZtSpUqVVyPvb29qV27dqFDzE5nsVjo06dPvrbGjRvnO3bZsmUEBAQUWMzn9ttvP+f5i+q3336jW7duBarhQ4YMITU11bVoUevWrdm0aRMjR47kp59+IjExscC5WrduTVRUFC+99BJ//vknWVlZxRbnxbj11lsLtA0bNoy0tDRmz57tapsxYwZeXl7ccccdAOzatYtt27bxn//8BwCHw+G69e7dm5iYGLZv3146FyEicoW5mj+bcxe5y606WywWhgwZwv79+/NN+1q4cCG1a9fmuuuuO+O5Fi5ciLe391kr2BeisM/O5ORknnzySWrWrImHhwceHh74+/uTkpKS73ezcOFCunbtSr169c54/urVq3PjjTfywQcfYOQsC/bFF19w4sQJRo0aVazXcr6uvfZaQkJC8rU1atSIFi1aMGPGDFfb1q1bWbNmTb73/ocffqBhw4Y0bdo0398NPXv2vCR2SpDSoURepJhEREQUaEtOTqZjx46sXr2al156iaVLl/LXX3/xzTffAOaiaOdSpkyZAm1eXl5FOtbX1xdvb+8Cx6anp7senzhxgvDw8ALHFtZ2oU6cOFHo+1OhQgXX8wBjx47ljTfe4M8//6RXr16UKVOGbt26sXb
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 32
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:38.803068Z",
"start_time": "2025-06-24T16:40:38.798344Z"
}
},
"cell_type": "code",
"source": [
"class ParameterizedQuantumConvolutionalLayer(nn.Module):\n",
" def __init__(self, nqubit, num_circuits):\n",
" super().__init__()\n",
" self.nqubit = nqubit\n",
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"\n",
" def circuit(self, nqubit):\n",
" cir = dq.QubitCircuit(nqubit)\n",
" cir.rxlayer(encode=True) #对原论文的量子线路结构并无影响,只是做了一个数据编码的操作\n",
" cir.barrier()\n",
" for iter in range(4): #对应原论文中一个量子卷积线路上的深度为4可控参数一共16个\n",
" cir.rylayer()\n",
" cir.cnot_ring()\n",
" cir.barrier()\n",
"\n",
" cir.observable(0)\n",
" return cir\n",
"\n",
" def forward(self, x):\n",
" kernel_size, stride = 2, 2\n",
" # [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]\n",
" x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)\n",
" w = int((x.shape[-1] - kernel_size) / stride + 1)\n",
" x_reshape = x_unflod.reshape(-1, self.nqubit)\n",
"\n",
" exps = []\n",
" for cir in self.cirs: # out_channels\n",
" cir(x_reshape)\n",
" exp = cir.expectation()\n",
" exps.append(exp)\n",
"\n",
" exps = torch.stack(exps, dim=1)\n",
" exps = exps.reshape(x.shape[0], 3, w, w)\n",
" return exps"
],
"id": "736fe987b84d5891",
"outputs": [],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:39.249632Z",
"start_time": "2025-06-24T16:40:38.925139Z"
}
},
"cell_type": "code",
"source": [
"# 此处我们可视化其中一个量子卷积核的线路结构:\n",
"net = ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3)\n",
"net.cirs[0].draw()"
],
"id": "e8058c7fde0a012b",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 2210.55x785.944 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABroAAAJxCAYAAAAdC2LsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAzpxJREFUeJzs3Xd4U2X/BvA7abr3oi2FQplt2Xtv2UtkKuBAHKCCgqCgqAxR3IogICgiiOy9l1L2hgJtKdCW7tK9R8bvj/6s9qWFJE1yknPuz3W9l5qc55xv3pzmTs7znOeRaTQaDYiIiIiIiIiIiIiIiIgsjFzoAoiIiIiIiIiIiIiIiIj0wY4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskjs6CIiIiIiIiIiIiIiIiKLxI4uIiIiIiIiIiIiIiIiskgKoQsgqo6LFy/qtH1aWhq2b9+OZ555Bl5eXlq1adeunT6lERGRieiSBfrkAMAsICIyZ8wBIiJp47UhIiLiHV0kKWlpaVi9ejXS0tKELoWIiATAHCAikjbmABERMQuIiMSHHV1ERERERERERERERERkkdjRRURERERERERERERERBaJHV1ERERERERERERERERkkdjRRZLi7OyMAQMGwNnZWehSiIhIAMwBIiJpYw4QERGzgIhIfGQajUYjdBFE+rp48aLRj9GuXTujH4OIiPTHLCAikjbmABGRtDEHiIiId3SRpBQXFyMuLg7FxcVCl0JERAJgDhARSRtzgIiImAVEROLDji6SlOjoaIwcORLR0dFCl0JERAJgDhARSRtzgIiImAVEROKjELoAqpxGo4Gy0HJGlijsbSGTyYQuQzQs7f0HeA4QERmapWUBc4CIyLAsLQcAZgERkSExB4iItMeOLjOlLCzGhvoThC5Da+PvrYe1g53QZYiGpb3/AM8BIiJDs7QsYA4QERmWpeUAwCwgIjIk5gARkfY4dSERERERERERERERERFZJHZ0ERERERERERERERERkUXi1IUkKUFBQbhw4YLQZRARkUCYA0RE0sYcICIiZgERkfjwji4iIiIiIiIiIiIiIiKySOzoIkmJjY3FpEmTEBsbK3QpREQkAOYAEZG0MQeIiIhZQEQkPpy6kCSlsLAQN2/eRGFhodClEBGZRGxiLi7dSsPl2+mIepCNwmIVrOQyuLvYomVjD7QJ8UKrYE84OVgLXapJMAeISGoyc4r/PwfSEBaVidz8Umg0Gjg6KBAc6Ia2TbzQtok3fDzthS7VJJgDRCQ1xSUqXI/MwOXbabgSnoa0rGKUKtWwtbZCXX8ntAn2QpsQTzSq6wqZTCZ0uSbBLCAiKVGrNbgTm43Lt8t+E8Qm5qO4VAVrhRxebrZoE+KFNiFeaN7IA7Y2VkKXqzd2dBEREYlMYZESfx68j+WbwnHpVlqV2/32//+0t7PCswPr442xwWgd4mWaIomIyGg0Gg2On0/C8k3h2PVXLFQqzRPbPNWxJt4YF4wh3QOgUHDiDyIiS3cnJhs/bQ7H2l1RyMoteeL2jeu6YsqYILwwrCHcXGxNUCERERlTZk4x1u6Kwk+bwxEVm1Pldqu33wEAuLvY4KWnG2HKmGA0CHAxVZkGw44uEfHt1AQDts+v8FhpfiFy7ifh3taTCF+zHxqVWqDqyBR4DhBJm0ajwe977mLGV+eRnlWsdbvCIhV+2XEHv+y4g6c61sSqj7oisJazESslY2EOENHl22mY/EkorkVk6NTu6LlEHD2XiLo1nbDyoy7o17mWkSokY2MWEEnbw4xCTPv8HP48eF+ndpEx2Xj7i/OY+8NlfDKlFWY83xRWVhz4YImYA0TSplSq8dVvYViw8ioKi1Rat8vMKcE3627im3U3MX5wfXz/Xkd4utkZsVLDYkeXCEXvPI24o5cAmQz23m5oMLoH2s9/Ea4N/XF21kqhyyMT4DlAJD2Jqfl4dcFp7DsZV639HD2XiGYjt+OLd9phythgyUxfIjbMASLpKSlVYeHKa/hszXWt7uCqSkxiHvq/fgivjGyMr99tD2dHGwNWSabELCCSnu1HY/D6wtN4mFmk9z4KipSY/e1FbD8Wg18XdkdQoJvhCiSTYg4QSc/te5l4cd5JXLxZ9ew+2tiw7x6OnkvEyo+6YHivOgaqzrg4NEOEMm5F4/62UNzfehK3ftqNfYPnIi/hIRo91we2npZ326Eh+fn5Yf78+fDz8xO6FKPiOUAkLZHRWegwfk+1O7n+kV+oxBuLz2LKojNQiWykH3OAOUAkRgWFSgyfdhSLVl2rVifXf/28LRI9XtqP1HRxrV8ilRwAmAVEUvPZ6usYOeNYtTq5/uvcjYfoMH43Tl1JNsj+zIlUsoA5QCQtf19KQscJe6rdyfWPlPRCPD39KL789YZB9mds7OiSAGVhMdKu3IVMLodLHR+hyxGUq6srBg4cCFdXV6FLMSmeA0TidS8uBz1f3o/4lHyD73vllghMWXQGGo1hLpqaA+YAc4BIbIpLVBg+/QgOno43+L6vRqTjqVcPIDNH++lwzZ1UcwBgFhCJ2edrrmPuD5cMvt+cvFIMmHII566nGnzfQpJqFjAHiMTr1JVkDJx6CLn5pQbf9+xvL+KrtWEG36+hsaNLIpzrlgVYUWauwJUIKzMzE1u2bEFmZqbQpZgczwEi8SksUmLwG4eRnGa80fY/b4vEN+tuGm3/psYcYA4Qic20z8/i6LlEo+0/LCoTY2cdF82gBynnAMAsIBKjHcdiMOd7w3dy/SO/UIkhbx1GclqB0Y5halLOAuYAkfgkpuZj2LQjOq3HpatZ31zA3r8fGG3/hsCOLhGysreFrYczbD1d4BYUgA6LJ8OzWT08vBqF3Gjx3XKui5SUFHz55ZdISUkRuhSj4jlAJA0f/ngZkTHZOrW5uHEY4o6Mw8WNw3Q7TnSWjtWZJ+YAc4BITA6ficeqrZE6tdEnB46cTcTP23Q7jrmSSg4AzAIiKUjLLMLrC0/r1EafHEjPKhbVTA9SyQLmAJH4aTQavLrgNDJzSnRqp08WlB3HfGd6UAhdgCmkpaXh888/x44dO5CQkABvb2+MGTMGixYtwiuvvIINGzbg559/xuTJk4Uu1SBazhiNljNGV3gs9sB5nHvvZ4EqIlPjOUAkfmevp+Db33W/08rXywG1fBx1alNUrMJLH4Xi9LohkMlkOh+TTI85QCR+eQWlmPzJKZ3b6ZMDADDzqwsY2LUWavs66dyWhMEsIBK/6UvOITVDtzW59M2BncdjsengfYwbWF/ntiQM5gCR+G3Yd0+v9dr1yYKkhwV4e8k5/PZpD52PZwqi7+i6ceMG+vXrh5SUFDg6OiIkJAQJCQn45ptvEBMTgwcPym65a9GihcCVGs6dDUcRves05AoruDUOQLO3noa9lxuURf/27MptFBh66Avc33EKYT9sL3+863dvwM7bDUfHfypE6WQg2pwDPVfNhEajwd+vfVP+mI2bE57+61tcWrAO97eHClG6wWk0Gly8mYblm8KxLzQO2bklcLBXoE2wJ6aMCcbwXnVgbc2bW8nyLP75Okw
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T16:40:39.356624Z",
"start_time": "2025-06-24T16:40:39.353193Z"
}
},
"cell_type": "code",
"source": [
"# QCCNN整体网络架构\n",
"class QCCNN(nn.Module):\n",
" def __init__(self):\n",
" super(QCCNN, self).__init__()\n",
" self.conv = nn.Sequential(\n",
" ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=1)\n",
" )\n",
"\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(8 * 8 * 3, 128),\n",
" nn.Dropout(0.4),\n",
" nn.ReLU(),\n",
" nn.Linear(128, 10)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = x.reshape(x.size(0), -1)\n",
" x = self.fc(x)\n",
" return x"
],
"id": "e3c6160fff06bed2",
"outputs": [],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:19:26.744108Z",
"start_time": "2025-06-24T16:40:39.450592Z"
}
},
"cell_type": "code",
"source": [
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = QCCNN()\n",
"model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-5) # 添加正则化项\n",
"optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(optim_model.state_dict(), './data/notebook1/qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook1/qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "34202fca380ee084",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 0.707 Valid Acc: 0.739: 100%|██████████| 300/300 [38:47<00:00, 7.76s/it]\n"
]
}
],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:19:31.162804Z",
"start_time": "2025-06-24T17:19:26.901586Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook1/qccnn_weights.pt', map_location=device)\n",
"qccnn_model = QCCNN()\n",
"qccnn_model.load_state_dict(state_dict)\n",
"qccnn_model.to(device)\n",
"\n",
"test_acc = test_model(qccnn_model, test_loader, device)"
],
"id": "f613b1c9a9ea0cd6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.752\n"
]
}
],
"execution_count": 37
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:19:31.278383Z",
"start_time": "2025-06-24T17:19:31.271353Z"
}
},
"cell_type": "code",
"source": [
"def vgg_block(in_channel,out_channel,num_convs):\n",
" layers = nn.ModuleList()\n",
" assert num_convs >= 1\n",
" layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1))\n",
" layers.append(nn.ReLU())\n",
" for _ in range(num_convs-1):\n",
" layers.append(nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1))\n",
" layers.append(nn.ReLU())\n",
" layers.append(nn.MaxPool2d(kernel_size=2,stride=2))\n",
" return nn.Sequential(*layers)\n",
"\n",
"VGG = nn.Sequential(\n",
" vgg_block(1,10,3), # 14,14\n",
" vgg_block(10,16,3), # 4 * 4\n",
" nn.Flatten(),\n",
" nn.Linear(16 * 4 * 4, 120),\n",
" nn.Sigmoid(),\n",
" nn.Linear(120, 84),\n",
" nn.Sigmoid(),\n",
" nn.Linear(84,10),\n",
" nn.Softmax(dim=-1)\n",
")"
],
"id": "37cc9edc6c4b035d",
"outputs": [],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:25:57.298293Z",
"start_time": "2025-06-24T17:19:31.391257Z"
}
},
"cell_type": "code",
"source": [
"num_epochs = 300\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"vgg_model = VGG\n",
"vgg_model.to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(vgg_model.parameters(), lr=1e-5) # 添加正则化项\n",
"vgg_model, metrics = train_model(vgg_model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)\n",
"torch.save(vgg_model.state_dict(), './data/notebook1/vgg_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试\n",
"pd.DataFrame(metrics).to_csv('./data/notebook1/vgg_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示"
],
"id": "643da0fb0433f438",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train loss: 1.776 Valid Acc: 0.732: 100%|██████████| 300/300 [06:25<00:00, 1.29s/it]\n"
]
}
],
"execution_count": 39
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:25:58.336844Z",
"start_time": "2025-06-24T17:25:57.506934Z"
}
},
"cell_type": "code",
"source": [
"state_dict = torch.load('./data/notebook1/vgg_weights.pt', map_location=device)\n",
"vgg_model = VGG\n",
"vgg_model.load_state_dict(state_dict)\n",
"vgg_model.to(device)\n",
"\n",
"vgg_test_acc = test_model(vgg_model, test_loader, device)"
],
"id": "cc56710965ab7c82",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Acc: 0.742\n"
]
}
],
"execution_count": 40
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:25:58.691203Z",
"start_time": "2025-06-24T17:25:58.488844Z"
}
},
"cell_type": "code",
"source": [
"vgg_data = pd.read_csv('./data/notebook1/vgg_metrics.csv')\n",
"qccnn_data = pd.read_csv('./data/notebook1/qccnn_metrics.csv')\n",
"vgg_epoch = vgg_data['epoch']\n",
"vgg_train_loss = vgg_data['train_loss']\n",
"vgg_valid_loss = vgg_data['valid_loss']\n",
"vgg_train_acc = vgg_data['train_acc']\n",
"vgg_valid_acc = vgg_data['valid_acc']\n",
"\n",
"qccnn_epoch = qccnn_data['epoch']\n",
"qccnn_train_loss = qccnn_data['train_loss']\n",
"qccnn_valid_loss = qccnn_data['valid_loss']\n",
"qccnn_train_acc = qccnn_data['train_acc']\n",
"qccnn_valid_acc = qccnn_data['valid_acc']\n",
"\n",
"# 创建图和Axes对象\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"# 绘制训练损失曲线\n",
"ax1.plot(vgg_epoch, vgg_train_loss, label='VGG Train Loss')\n",
"ax1.plot(vgg_epoch, vgg_valid_loss, label='VGG Valid Loss')\n",
"ax1.plot(qccnn_epoch, qccnn_train_loss, label='QCCNN Valid Loss')\n",
"ax1.plot(qccnn_epoch, qccnn_valid_loss, label='QCCNN Valid Loss')\n",
"ax1.set_title('Training Loss Curve')\n",
"ax1.set_xlabel('Epoch')\n",
"ax1.set_ylabel('Loss')\n",
"ax1.legend()\n",
"\n",
"# 绘制训练准确率曲线\n",
"ax2.plot(vgg_epoch, vgg_train_acc, label='VGG Train Accuracy')\n",
"ax2.plot(vgg_epoch, vgg_valid_acc, label='VGG Valid Accuracy')\n",
"ax2.plot(qccnn_epoch, qccnn_train_acc, label='QCCNN Train Accuracy')\n",
"ax2.plot(qccnn_epoch, qccnn_valid_acc, label='QCCNN Valid Accuracy')\n",
"ax2.set_title('Training Accuracy Curve')\n",
"ax2.set_xlabel('Epoch')\n",
"ax2.set_ylabel('Accuracy')\n",
"ax2.legend()\n",
"\n",
"plt.show()"
],
"id": "8e450f8cfb2812d2",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+kAAAHUCAYAAABGRmklAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xd0FFUbwOHf1vTeCJAGhITQew2996JUgUAsNAEpUlRAUVCQKgKikIB0pEov0qV3CJ2QEAiEkl52s7vz/bGynzEBQt0E7nPOHtnZO3femcTMvnObTJIkCUEQBEEQBEEQBEEQzE5u7gAEQRAEQRAEQRAEQTASSbogCIIgCIIgCIIg5BEiSRcEQRAEQRAEQRCEPEIk6YIgCIIgCIIgCIKQR4gkXRAEQRAEQRAEQRDyCJGkC4IgCIIgCIIgCEIeIZJ0QRAEQRAEQRAEQcgjRJIuCIIgCIIgCIIgCHmESNIFQRAEQRAEQRAEIY8QSbrwTpPJZLl67dmz56WOM27cOGQy2Qvtu2fPnlcSw8sc+48//njjx34RZ8+epVevXvj5+WFpaYmtrS0VKlRg0qRJPHr0yNzhCYIgCLkg7s25t2HDBmQyGS4uLmg0GrPGkh8lJSXx3XffUalSJezt7bGwsMDX15fevXtz8uRJc4cnvMOU5g5AEMzp0KFDWd6PHz+e3bt389dff2XZHhQU9FLH+fDDD2natOkL7VuhQgUOHTr00jG87X799Vf69etHQEAAw4cPJygoiMzMTI4fP87cuXM5dOgQa9euNXeYgiAIwjOIe3PuzZ8/H4BHjx6xbt06OnXqZNZ48pPr16/TuHFj4uLi6NOnD19//TW2trbcvHmTlStXUrFiRRISEnBwcDB3qMI7SCZJkmTuIAQhrwgJCeGPP/4gJSXlqeXS0tKwtrZ+Q1GZz549e6hXrx6rVq3ivffeM3c4T3To0CGCg4Np1KgR69atw8LCIsvnWq2WrVu30rp165c+Vnp6OpaWli/c+iIIgiA8H3Fvztndu3fx8vKidu3a/P333wQHB7N9+3Zzh5WjvPaz0ev1lC9fnqioKA4ePEipUqWyldmyZQt16tR56bglSSIjIwMrK6uXqkd4t4ju7oLwDHXr1qVUqVLs27ePGjVqYG1tTe/evQFYsWIFjRs3xtPTEysrK0qUKMHIkSNJTU3NUkdOXep8fX1p2bIlW7dupUKFClhZWREYGMiCBQuylMupS11ISAi2trZcu3aN5s2bY2tri5eXF0OHDs3W3S0mJob33nsPOzs7HB0d6datG8eOHUMmkxEeHv5KrtH58+dp06YNTk5OWFpaUq5cORYuXJiljMFg4NtvvyUgIAArKyscHR0pU6YMM2bMMJW5f/8+H3/8MV5eXlhYWODm5kbNmjXZuXPnU48/YcIEZDIZ8+bNy5agA6jV6iwJukwmY9y4cdnK+fr6EhISYnofHh6OTCZj+/bt9O7dGzc3N6ytrVmxYgUymYxdu3Zlq2POnDnIZDLOnj1r2nb8+HFat26Ns7MzlpaWlC9fnpUrVz71nARBEIQnE/dmWLhwITqdjs8++4z27duza9cuoqKispVLSEhg6NChFClSBAsLC9zd3WnevDmXLl0yldFoNHzzzTeUKFECS0tLXFxcqFevHn///TcAN2/efGJs/72nPr6uJ0+e5L333sPJyYmiRYsCxvth586d8fX1xcrKCl9fX7p06ZJj3Ldv3zZ9J1Cr1RQsWJD33nuPe/fukZKSgqOjI5988km2/W7evIlCoWDy5MlPvHbr1q3j3LlzjBo1KscEHaBZs2amBD0kJARfX99sZXL6HZLJZAwYMIC5c+dSokQJLCws+O2333B3d6d79+7Z6khISMDKyoohQ4aYtiUlJTFs2DD8/PxQq9UUKlSIwYMHZ/sdFt5eoru7IORCbGwsH3zwAZ9//jkTJkxALjc+37p69SrNmzdn8ODB2NjYcOnSJX744QeOHj2arVteTs6cOcPQoUMZOXIkHh4e/Pbbb4SGhlKsWDFq16791H0zMzNp3bo1oaGhDB06lH379jF+/HgcHBwYM2YMAKmpqdSrV49Hjx7xww8/UKxYMbZu3fpKu8NdvnyZGjVq4O7uzsyZM3FxcWHx4sWEhIRw7949Pv/8cwAmTZrEuHHj+PLLL6lduzaZmZlcunSJhIQEU13du3fn5MmTfPfddxQvXpyEhAROnjzJw4cPn3h8vV7PX3/9RcWKFfHy8npl5/VvvXv3pkWLFvz++++kpqbSsmVL3N3dCQsLo0GDBlnKhoeHU6FCBcqUKQPA7t27adq0KVWrVmXu3Lk4ODiwfPlyOnXqRFpaWpaHAoIgCELuvev35gULFuDp6UmzZs2wsrJi6dKlhIeHM3bsWFOZ5ORkatWqxc2bNxkxYgRVq1YlJSWFffv2ERsbS2BgIDqdjmbNmrF//34GDx5M/fr10el0HD58mOjoaGrUqPFccT3Wvn17OnfuTJ8+fUzJ5c2bNwkICKBz5844OzsTGxvLnDlzqFy5MhEREbi6ugLGBL1y5cpkZmYyevRoypQpw8OHD9m2bRvx8fF4eHjQu3dv5s2bx6RJk7J0SZ89ezZqtdr00CYnj3sctG3b9oXO7VnWrVvH/v37GTNmDAUKFMDd3Z3IyEjmzp3Lzz//jL29vanssmXLyMjIoFevXoCx10GdOnWIiYkxnfuFCxcYM2YM586dY+fOnaI337tAEgTBpGfPnpKNjU2WbXXq1JEAadeuXU/d12AwSJmZmdLevXslQDpz5ozps7Fjx0r//d/Nx8dHsrS0lKKiokzb0tPTJWdnZ+mTTz4xbdu9e7cESLt3784SJyCtXLkyS53NmzeXAgICTO9//vlnCZC2bNmSpdwnn3wiAVJYWNhTz+nxsVetWvXEMp07d5YsLCyk6OjoLNubNWsmWVtbSwkJCZIkSVLLli2lcuXKPfV4tra20uDBg59a5r/u3r0rAVLnzp1zvQ8gjR07Ntt2Hx8fqWfPnqb3YWFhEiD16NEjW9khQ4ZIVlZWpvOTJEmKiIiQAOmnn34ybQsMDJTKly8vZWZmZtm/ZcuWkqenp6TX63MdtyAIwrtI3Juz27dvnwRII0eONJ2nn5+f5OPjIxkMBlO5b775RgKkHTt2PLGuRYsWSYD066+/PrFMZGTkE2P77z318XUdM2bMM89Dp9NJKSkpko2NjTRjxgzT9t69e0sqlUqKiIh44r7Xr1+X5HK5NG3aNNO29PR0ycXFRerVq9dTj9u0aVMJkDIyMp4ZoyQZf7Y+Pj7Ztuf0OwRIDg4O0qNHj7JsP3v2rARI8+bNy7K9SpUqUsWKFU3vJ06cKMnlcunYsWNZyv3xxx8SIG3evDlXMQv5m+juLgi54OTkRP369bNtv3HjBl27dqVAgQIoFApUKhV16tQB4OLFi8+st1y5cnh7e5veW1paUrx48Ry7ff2XTCajVatWWbaVKVMmy7579+7Fzs4u28Q4Xbp0eWb9ufXXX3/RoEGDbK3YISEhpKWlmSYAqlKlCmfOnKFfv35s27aNpKSkbHVVqVKF8PBwvv32Ww4fPkxmZuYri/NldOjQIdu23r17k56ezooVK0zbwsLCsLCwoGvXrgBcu3aNS5cu0a1bNwB0Op3p1bx5c2JjY7l8+fKbOQlBEIS3zLt8b348Ydzj1mKZTEZISAhRUVFZhmJt2bKF4sWL07BhwyfWtWXLFiwtLZ/a8vwicrp3pqSkMGLECIoVK4ZSqUSpVGJra0tqamqWn82WLVuoV68eJUqUeGL9RYoUoWXLlsyePRvpnym2li5dysOHDxkwYMArPZfnVb9+fZycnLJsK126NBUrViQsLMy07eLFixw9ejTLtd+4cSOlSpWiXLlyWb43NGnSJE+sKCC8GSJJF4Rc8PT0zLYtJSWF4OBgjhw5wrfffsuePXs4duwYa9asAYwTjD2Li4tLtm0WFha52tfa2hpLS8ts+2ZkZJjeP3z4EA8Pj2z75rTtRT18+DDH61OwYEHT5wCjRo3ixx9/5PDhwzRr1gwXFxcaNGjA8ePHTfusWLGCnj178ttvv1G
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 41
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:25:58.850907Z",
"start_time": "2025-06-24T17:25:58.833078Z"
}
},
"cell_type": "code",
"source": [
"# 这里我们对比不同模型之间可训练参数量的区别\n",
"\n",
"def count_parameters(model):\n",
" \"\"\"\n",
" 计算模型的参数数量\n",
" \"\"\"\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"number_params_VGG = count_parameters(VGG)\n",
"number_params_QCCNN = count_parameters(QCCNN())\n",
"print(f'VGG 模型可训练参数量:{number_params_VGG}\\t QCCNN模型可训练参数量{number_params_QCCNN}')"
],
"id": "9675ba847f4a998d",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"VGG 模型可训练参数量49870\t QCCNN模型可训练参数量26042\n"
]
}
],
"execution_count": 42
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}