
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.323528Z",
"start_time": "2025-06-24T17:37:28.321407Z"
}
},
"cell_type": "code",
"source": "# Modify.py\n",
"id": "98da35e8f6af6b7a",
"outputs": [],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.355899Z",
"start_time": "2025-06-24T17:37:28.352650Z"
}
},
"cell_type": "code",
"source": [
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import deepquantum as dq\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision.transforms as transforms\n",
"from torchvision.datasets import FashionMNIST\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"from multiprocessing import freeze_support\n"
],
"id": "fba02718b5ce470f",
"outputs": [],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.376643Z",
"start_time": "2025-06-24T17:37:28.373463Z"
}
},
"cell_type": "code",
"source": [
"def seed_torch(seed=1024):\n",
"    \"\"\"Seed every RNG used in this notebook for reproducible runs.\n",
"\n",
"    Covers Python's `random` module, PYTHONHASHSEED, NumPy, and PyTorch\n",
"    (CPU plus all CUDA devices), and forces deterministic cuDNN behaviour.\n",
"    \"\"\"\n",
"    random.seed(seed)\n",
"    os.environ['PYTHONHASHSEED'] = str(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    # Trade cuDNN autotuning speed for run-to-run determinism.\n",
"    torch.backends.cudnn.benchmark = False\n",
"    torch.backends.cudnn.deterministic = True\n"
],
"id": "e21c1ebc100b5079",
"outputs": [],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.406416Z",
"start_time": "2025-06-24T17:37:28.403413Z"
}
},
"cell_type": "code",
"source": [
"def calculate_score(y_true, y_preds):\n",
"    \"\"\"Return the accuracy of one batch as a NumPy scalar.\n",
"\n",
"    y_true:  (N,) tensor of integer class labels.\n",
"    y_preds: (N, C) tensor of raw logits.\n",
"    \"\"\"\n",
"    # softmax is monotonic, so argmax over logits equals argmax over\n",
"    # probabilities -- the intermediate softmax was redundant work.\n",
"    preds_class = torch.argmax(y_preds, dim=1)\n",
"    correct = (preds_class == y_true).float()\n",
"    return (correct.sum() / len(correct)).cpu().numpy()\n"
],
"id": "f70cc647d264747",
"outputs": [],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.439020Z",
"start_time": "2025-06-24T17:37:28.433485Z"
}
},
"cell_type": "code",
"source": [
"def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):\n",
"    \"\"\"Train `model`, validate once per epoch, and checkpoint the best weights.\n",
"\n",
"    Args:\n",
"        model: network to train (moved onto `device`).\n",
"        criterion: loss function, e.g. nn.CrossEntropyLoss.\n",
"        optimizer: optimizer over model.parameters().\n",
"        scheduler: LR scheduler, stepped once per epoch.\n",
"        train_loader, valid_loader: DataLoaders for the two splits.\n",
"        num_epochs: number of full passes over train_loader.\n",
"        device: torch.device to run on.\n",
"        save_path: file path for the best-validation-accuracy state_dict.\n",
"\n",
"    Returns:\n",
"        (model, metrics) -- metrics maps 'epoch'/'train_acc'/'valid_acc'/\n",
"        'train_loss'/'valid_loss' to per-epoch lists.\n",
"    \"\"\"\n",
"    model.to(device)\n",
"    best_acc = 0.0\n",
"    metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}\n",
"\n",
"    for epoch in range(1, num_epochs + 1):\n",
"        # --- Training phase ---\n",
"        model.train()\n",
"        running_loss, running_acc = 0.0, 0.0\n",
"        for imgs, labels in train_loader:\n",
"            imgs, labels = imgs.to(device), labels.to(device)\n",
"            optimizer.zero_grad()\n",
"            outputs = model(imgs)\n",
"            loss = criterion(outputs, labels)\n",
"            loss.backward()\n",
"            # Clip gradients to keep training stable.\n",
"            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
"            optimizer.step()\n",
"            running_loss += loss.item()\n",
"            running_acc += calculate_score(labels, outputs)\n",
"\n",
"        # Per-epoch means over batches.\n",
"        train_loss = running_loss / len(train_loader)\n",
"        train_acc = running_acc / len(train_loader)\n",
"        scheduler.step()\n",
"\n",
"        # --- Validation phase ---\n",
"        model.eval()\n",
"        val_loss, val_acc = 0.0, 0.0\n",
"        with torch.no_grad():\n",
"            for imgs, labels in valid_loader:\n",
"                imgs, labels = imgs.to(device), labels.to(device)\n",
"                outputs = model(imgs)\n",
"                loss = criterion(outputs, labels)\n",
"                val_loss += loss.item()\n",
"                val_acc += calculate_score(labels, outputs)\n",
"\n",
"        valid_loss = val_loss / len(valid_loader)\n",
"        valid_acc = val_acc / len(valid_loader)\n",
"\n",
"        metrics['epoch'].append(epoch)\n",
"        metrics['train_loss'].append(train_loss)\n",
"        metrics['valid_loss'].append(valid_loss)\n",
"        metrics['train_acc'].append(train_acc)\n",
"        metrics['valid_acc'].append(valid_acc)\n",
"\n",
"        tqdm.write(f\"[{save_path}] Epoch {epoch}/{num_epochs} \"\n",
"                   f\"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}\")\n",
"\n",
"        # Checkpoint whenever validation accuracy improves.\n",
"        if valid_acc > best_acc:\n",
"            best_acc = valid_acc\n",
"            torch.save(model.state_dict(), save_path)\n",
"\n",
"    return model, metrics\n"
],
"id": "7cfe89d63ab0e68e",
"outputs": [],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.467170Z",
"start_time": "2025-06-24T17:37:28.463678Z"
}
},
"cell_type": "code",
"source": [
"def test_model(model, test_loader, device):\n",
"    \"\"\"Evaluate `model` on `test_loader`; print and return the mean batch accuracy.\"\"\"\n",
"    model.to(device).eval()\n",
"    total = 0.0\n",
"    with torch.no_grad():\n",
"        for batch_imgs, batch_labels in test_loader:\n",
"            batch_imgs = batch_imgs.to(device)\n",
"            batch_labels = batch_labels.to(device)\n",
"            logits = model(batch_imgs)\n",
"            total += calculate_score(batch_labels, logits)\n",
"    acc = total / len(test_loader)\n",
"    print(f\"Test Accuracy: {acc:.4f}\")\n",
"    return acc\n"
],
"id": "235ac16bd786e65f",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T17:37:28.506986Z",
"start_time": "2025-06-24T17:37:28.496003Z"
}
},
"cell_type": "code",
"source": [
"singlegate_list = ['rx','ry','rz','s','t','p','u3']\n",
"doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']\n",
"\n",
"class RandomQuantumConvolutionalLayer(nn.Module):\n",
"    \"\"\"2x2/stride-2 'quantum convolution' built from randomly structured circuits.\n",
"\n",
"    Each output channel is the expectation value of one random circuit\n",
"    evaluated on every 2x2 patch of the input.\n",
"    \"\"\"\n",
"    def __init__(self, nqubit, num_circuits, seed=1024):\n",
"        super().__init__()\n",
"        random.seed(seed)  # fixes the randomly drawn gate structure across runs\n",
"        self.nqubit = nqubit\n",
"        self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"    def circuit(self, nqubit):\n",
"        # rx encoding layer followed by 3 blocks of random single/two-qubit gates.\n",
"        cir = dq.QubitCircuit(nqubit)\n",
"        cir.rxlayer(encode=True); cir.barrier()\n",
"        for _ in range(3):\n",
"            for i in range(nqubit):\n",
"                getattr(cir, random.choice(singlegate_list))(i)\n",
"            c,t = random.sample(range(nqubit),2)\n",
"            gate = random.choice(doublegate_list)\n",
"            # rxx/ryy/rzz/swap take a wire list; controlled gates take (control, target).\n",
"            if gate[0] in ['r','s']:\n",
"                getattr(cir, gate)([c,t])\n",
"            else:\n",
"                getattr(cir, gate)(c,t)\n",
"            cir.barrier()\n",
"        cir.observable(0)\n",
"        return cir\n",
"    def forward(self, x):\n",
"        k,s = 2,2\n",
"        # Carve the image into non-overlapping k x k patches; one circuit input per patch.\n",
"        x_unf = x.unfold(2,k,s).unfold(3,k,s)\n",
"        w = (x.shape[-1]-k)//s + 1\n",
"        x_r = x_unf.reshape(-1, self.nqubit)\n",
"        exps = []\n",
"        for cir in self.cirs:\n",
"            cir(x_r)\n",
"            exps.append(cir.expectation())\n",
"        exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)\n",
"        return exps\n",
"\n",
"class RandomQCCNN(nn.Module):\n",
"    \"\"\"Random quantum conv layer followed by a small classical CNN head.\"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv = nn.Sequential(\n",
"            RandomQuantumConvolutionalLayer(4,3,seed=1024),\n",
"            nn.ReLU(), nn.MaxPool2d(2,1),\n",
"            nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)\n",
"        )\n",
"        self.fc = nn.Sequential(\n",
"            nn.Linear(6*6*6,1024), nn.Dropout(0.4),\n",
"            nn.Linear(1024,10)\n",
"        )\n",
"    def forward(self,x):\n",
"        x = self.conv(x)\n",
"        x = x.view(x.size(0),-1)\n",
"        return self.fc(x)\n",
"\n",
"class ParameterizedQuantumConvolutionalLayer(nn.Module):\n",
"    \"\"\"Trainable quantum conv layer: ry rotations + cnot ring entanglement.\"\"\"\n",
"    def __init__(self,nqubit,num_circuits):\n",
"        super().__init__()\n",
"        self.nqubit = nqubit\n",
"        self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
"    def circuit(self,nqubit):\n",
"        cir = dq.QubitCircuit(nqubit)\n",
"        cir.rxlayer(encode=True); cir.barrier()\n",
"        for _ in range(4):\n",
"            cir.rylayer(); cir.cnot_ring(); cir.barrier()\n",
"        cir.observable(0)\n",
"        return cir\n",
"    def forward(self,x):\n",
"        k,s = 2,2\n",
"        x_unf = x.unfold(2,k,s).unfold(3,k,s)\n",
"        w = (x.shape[-1]-k)//s +1\n",
"        x_r = x_unf.reshape(-1,self.nqubit)\n",
"        exps = []\n",
"        for cir in self.cirs:\n",
"            cir(x_r); exps.append(cir.expectation())\n",
"        exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)\n",
"        return exps\n",
"\n",
"class QCCNN(nn.Module):\n",
"    \"\"\"Parameterized quantum conv layer followed by an MLP classifier.\"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv = nn.Sequential(\n",
"            ParameterizedQuantumConvolutionalLayer(4,3),\n",
"            nn.ReLU(), nn.MaxPool2d(2,1)\n",
"        )\n",
"        self.fc = nn.Sequential(\n",
"            nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),\n",
"            nn.Linear(128,10)\n",
"        )\n",
"    def forward(self,x):\n",
"        x = self.conv(x); x = x.view(x.size(0),-1)\n",
"        return self.fc(x)\n",
"\n",
"def vgg_block(in_c,out_c,n_convs):\n",
"    \"\"\"`n_convs` 3x3 conv+ReLU layers followed by a 2x2 max-pool.\"\"\"\n",
"    layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]\n",
"    for _ in range(n_convs-1):\n",
"        layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]\n",
"    layers.append(nn.MaxPool2d(2,2))\n",
"    return nn.Sequential(*layers)\n",
"\n",
"# Classical baseline. NOTE: the final layer now emits raw logits --\n",
"# nn.CrossEntropyLoss applies log-softmax internally, so the previous\n",
"# trailing nn.Softmax(dim=-1) was a double softmax that flattens gradients\n",
"# and slows convergence.\n",
"VGG = nn.Sequential(\n",
"    vgg_block(1,10,3),\n",
"    vgg_block(10,16,3),\n",
"    nn.Flatten(),\n",
"    nn.Linear(16*4*4,120), nn.Sigmoid(),\n",
"    nn.Linear(120,84), nn.Sigmoid(),\n",
"    nn.Linear(84,10)\n",
")\n"
],
"id": "c996822b3d5f8305",
"outputs": [],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-24T18:51:44.622880Z",
"start_time": "2025-06-24T17:37:28.529442Z"
}
},
"cell_type": "code",
"source": [
"if __name__ == '__main__':\n",
"    freeze_support()\n",
"\n",
"    # Augmented preprocessing for training vs. deterministic preprocessing for eval.\n",
"    train_transform = transforms.Compose([\n",
"        transforms.Resize((18, 18)),\n",
"        transforms.RandomRotation(15),\n",
"        transforms.RandomHorizontalFlip(),\n",
"        transforms.RandomVerticalFlip(0.3),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.5,), (0.5,))\n",
"    ])\n",
"    eval_transform = transforms.Compose([\n",
"        transforms.Resize((18, 18)),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.5,), (0.5,))\n",
"    ])\n",
"\n",
"    # BUG FIX: random_split returns two Subsets that share ONE underlying\n",
"    # dataset, so the old `valid_ds.dataset.transform = eval_transform` also\n",
"    # disabled augmentation for the training split. Build two views of the\n",
"    # same training data (one transform each) and split them with a shared,\n",
"    # seeded index permutation instead.\n",
"    full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)\n",
"    full_train_eval = FashionMNIST(root='./data/notebook2', train=True, transform=eval_transform, download=True)\n",
"    test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)\n",
"    train_size = int(0.8 * len(full_train))\n",
"    perm = torch.randperm(len(full_train), generator=torch.Generator().manual_seed(1024)).tolist()\n",
"    train_ds = torch.utils.data.Subset(full_train, perm[:train_size])\n",
"    valid_ds = torch.utils.data.Subset(full_train_eval, perm[train_size:])\n",
"\n",
"    batch_size = 128\n",
"    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)\n",
"    valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)\n",
"    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)\n",
"\n",
"    # Three model configurations: name -> (model, learning rate, checkpoint path).\n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"    models = {\n",
"        'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),\n",
"        'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),\n",
"        'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')\n",
"    }\n",
"\n",
"    all_metrics = {}\n",
"    for name, (model, lr, save_path) in models.items():\n",
"        seed_torch(1024)\n",
"        model = model.to(device)\n",
"        criterion = nn.CrossEntropyLoss()\n",
"        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)\n",
"        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n",
"\n",
"        print(f\"\\n=== Training {name} ===\")\n",
"        _, metrics = train_model(\n",
"            model, criterion, optimizer, scheduler,\n",
"            train_loader, valid_loader,\n",
"            num_epochs=50, device=device, save_path=save_path\n",
"        )\n",
"        all_metrics[name] = metrics\n",
"        pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)\n",
"\n",
"    # Reload each best checkpoint, test it, and plot validation-accuracy curves.\n",
"    plt.figure(figsize=(12,5))\n",
"    for i,(name,metrics) in enumerate(all_metrics.items(),1):\n",
"        model, _, save_path = models[name]\n",
"        best_model = model.to(device)\n",
"        best_model.load_state_dict(torch.load(save_path))\n",
"        print(f\"\\n--- Testing {name} ---\")\n",
"        test_model(best_model, test_loader, device)\n",
"\n",
"        plt.subplot(1,3,i)\n",
"        plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n",
"        plt.xlabel('Epoch'); plt.ylabel('Valid Acc')\n",
"        plt.title(name); plt.legend()\n",
"\n",
"    plt.tight_layout(); plt.show()\n",
"\n",
"    # Trainable-parameter counts, to weigh model size against accuracy.\n",
"    def count_parameters(m):\n",
"        return sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
"\n",
"    print(\"\\nParameter Counts:\")\n",
"    for name,(model,_,_) in models.items():\n",
"        print(f\"{name}: {count_parameters(model)}\")\n"
],
"id": "2d6b93bb78001086",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"=== Training random_qccnn ===\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 1/50 Train Acc: 0.6377 Valid Acc: 0.7426\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 2/50 Train Acc: 0.7589 Valid Acc: 0.7799\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 3/50 Train Acc: 0.7830 Valid Acc: 0.7955\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 4/50 Train Acc: 0.7928 Valid Acc: 0.8010\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 5/50 Train Acc: 0.7997 Valid Acc: 0.8065\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 6/50 Train Acc: 0.8044 Valid Acc: 0.8140\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 7/50 Train Acc: 0.8097 Valid Acc: 0.8157\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 8/50 Train Acc: 0.8155 Valid Acc: 0.8163\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 9/50 Train Acc: 0.8162 Valid Acc: 0.8159\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 10/50 Train Acc: 0.8169 Valid Acc: 0.8160\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 11/50 Train Acc: 0.8210 Valid Acc: 0.8269\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 12/50 Train Acc: 0.8210 Valid Acc: 0.8266\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 13/50 Train Acc: 0.8241 Valid Acc: 0.8212\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 14/50 Train Acc: 0.8240 Valid Acc: 0.8264\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 15/50 Train Acc: 0.8245 Valid Acc: 0.8231\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 16/50 Train Acc: 0.8270 Valid Acc: 0.8291\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 17/50 Train Acc: 0.8274 Valid Acc: 0.8297\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 18/50 Train Acc: 0.8281 Valid Acc: 0.8338\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 19/50 Train Acc: 0.8295 Valid Acc: 0.8291\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 20/50 Train Acc: 0.8306 Valid Acc: 0.8304\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 21/50 Train Acc: 0.8320 Valid Acc: 0.8280\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 22/50 Train Acc: 0.8316 Valid Acc: 0.8293\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 23/50 Train Acc: 0.8315 Valid Acc: 0.8298\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 24/50 Train Acc: 0.8329 Valid Acc: 0.8282\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 25/50 Train Acc: 0.8321 Valid Acc: 0.8313\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 26/50 Train Acc: 0.8350 Valid Acc: 0.8311\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 27/50 Train Acc: 0.8343 Valid Acc: 0.8335\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 28/50 Train Acc: 0.8347 Valid Acc: 0.8320\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 29/50 Train Acc: 0.8354 Valid Acc: 0.8333\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 30/50 Train Acc: 0.8359 Valid Acc: 0.8314\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 31/50 Train Acc: 0.8380 Valid Acc: 0.8348\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 32/50 Train Acc: 0.8369 Valid Acc: 0.8330\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 33/50 Train Acc: 0.8375 Valid Acc: 0.8367\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 34/50 Train Acc: 0.8373 Valid Acc: 0.8315\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 35/50 Train Acc: 0.8381 Valid Acc: 0.8352\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 36/50 Train Acc: 0.8393 Valid Acc: 0.8374\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 37/50 Train Acc: 0.8384 Valid Acc: 0.8348\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 38/50 Train Acc: 0.8398 Valid Acc: 0.8355\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 39/50 Train Acc: 0.8402 Valid Acc: 0.8365\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 40/50 Train Acc: 0.8400 Valid Acc: 0.8346\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 41/50 Train Acc: 0.8411 Valid Acc: 0.8374\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 42/50 Train Acc: 0.8409 Valid Acc: 0.8373\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 43/50 Train Acc: 0.8415 Valid Acc: 0.8364\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 44/50 Train Acc: 0.8414 Valid Acc: 0.8359\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 45/50 Train Acc: 0.8409 Valid Acc: 0.8364\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 46/50 Train Acc: 0.8419 Valid Acc: 0.8371\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 47/50 Train Acc: 0.8422 Valid Acc: 0.8369\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 48/50 Train Acc: 0.8421 Valid Acc: 0.8360\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 49/50 Train Acc: 0.8417 Valid Acc: 0.8369\n",
"[./data/notebook2/random_qccnn_best.pt] Epoch 50/50 Train Acc: 0.8408 Valid Acc: 0.8369\n",
"\n",
"=== Training qccnn ===\n",
"[./data/notebook2/qccnn_best.pt] Epoch 1/50 Train Acc: 0.3113 Valid Acc: 0.5076\n",
"[./data/notebook2/qccnn_best.pt] Epoch 2/50 Train Acc: 0.4621 Valid Acc: 0.5548\n",
"[./data/notebook2/qccnn_best.pt] Epoch 3/50 Train Acc: 0.5201 Valid Acc: 0.5892\n",
"[./data/notebook2/qccnn_best.pt] Epoch 4/50 Train Acc: 0.5642 Valid Acc: 0.6311\n",
"[./data/notebook2/qccnn_best.pt] Epoch 5/50 Train Acc: 0.6056 Valid Acc: 0.6599\n",
"[./data/notebook2/qccnn_best.pt] Epoch 6/50 Train Acc: 0.6317 Valid Acc: 0.6804\n",
"[./data/notebook2/qccnn_best.pt] Epoch 7/50 Train Acc: 0.6514 Valid Acc: 0.6929\n",
"[./data/notebook2/qccnn_best.pt] Epoch 8/50 Train Acc: 0.6678 Valid Acc: 0.7008\n",
"[./data/notebook2/qccnn_best.pt] Epoch 9/50 Train Acc: 0.6809 Valid Acc: 0.7096\n",
"[./data/notebook2/qccnn_best.pt] Epoch 10/50 Train Acc: 0.6877 Valid Acc: 0.7186\n",
"[./data/notebook2/qccnn_best.pt] Epoch 11/50 Train Acc: 0.6976 Valid Acc: 0.7247\n",
"[./data/notebook2/qccnn_best.pt] Epoch 12/50 Train Acc: 0.7031 Valid Acc: 0.7303\n",
"[./data/notebook2/qccnn_best.pt] Epoch 13/50 Train Acc: 0.7119 Valid Acc: 0.7317\n",
"[./data/notebook2/qccnn_best.pt] Epoch 14/50 Train Acc: 0.7164 Valid Acc: 0.7404\n",
"[./data/notebook2/qccnn_best.pt] Epoch 15/50 Train Acc: 0.7211 Valid Acc: 0.7438\n",
"[./data/notebook2/qccnn_best.pt] Epoch 16/50 Train Acc: 0.7249 Valid Acc: 0.7481\n",
"[./data/notebook2/qccnn_best.pt] Epoch 17/50 Train Acc: 0.7294 Valid Acc: 0.7500\n",
"[./data/notebook2/qccnn_best.pt] Epoch 18/50 Train Acc: 0.7345 Valid Acc: 0.7518\n",
"[./data/notebook2/qccnn_best.pt] Epoch 19/50 Train Acc: 0.7350 Valid Acc: 0.7550\n",
"[./data/notebook2/qccnn_best.pt] Epoch 20/50 Train Acc: 0.7391 Valid Acc: 0.7587\n",
"[./data/notebook2/qccnn_best.pt] Epoch 21/50 Train Acc: 0.7434 Valid Acc: 0.7608\n",
"[./data/notebook2/qccnn_best.pt] Epoch 22/50 Train Acc: 0.7443 Valid Acc: 0.7634\n",
"[./data/notebook2/qccnn_best.pt] Epoch 23/50 Train Acc: 0.7481 Valid Acc: 0.7654\n",
"[./data/notebook2/qccnn_best.pt] Epoch 24/50 Train Acc: 0.7498 Valid Acc: 0.7683\n",
"[./data/notebook2/qccnn_best.pt] Epoch 25/50 Train Acc: 0.7529 Valid Acc: 0.7696\n",
"[./data/notebook2/qccnn_best.pt] Epoch 26/50 Train Acc: 0.7547 Valid Acc: 0.7708\n",
"[./data/notebook2/qccnn_best.pt] Epoch 27/50 Train Acc: 0.7547 Valid Acc: 0.7723\n",
"[./data/notebook2/qccnn_best.pt] Epoch 28/50 Train Acc: 0.7580 Valid Acc: 0.7736\n",
"[./data/notebook2/qccnn_best.pt] Epoch 29/50 Train Acc: 0.7571 Valid Acc: 0.7749\n",
"[./data/notebook2/qccnn_best.pt] Epoch 30/50 Train Acc: 0.7602 Valid Acc: 0.7760\n",
"[./data/notebook2/qccnn_best.pt] Epoch 31/50 Train Acc: 0.7610 Valid Acc: 0.7767\n",
"[./data/notebook2/qccnn_best.pt] Epoch 32/50 Train Acc: 0.7618 Valid Acc: 0.7764\n",
"[./data/notebook2/qccnn_best.pt] Epoch 33/50 Train Acc: 0.7630 Valid Acc: 0.7784\n",
"[./data/notebook2/qccnn_best.pt] Epoch 34/50 Train Acc: 0.7632 Valid Acc: 0.7791\n",
"[./data/notebook2/qccnn_best.pt] Epoch 35/50 Train Acc: 0.7627 Valid Acc: 0.7786\n",
"[./data/notebook2/qccnn_best.pt] Epoch 36/50 Train Acc: 0.7653 Valid Acc: 0.7803\n",
"[./data/notebook2/qccnn_best.pt] Epoch 37/50 Train Acc: 0.7640 Valid Acc: 0.7811\n",
"[./data/notebook2/qccnn_best.pt] Epoch 38/50 Train Acc: 0.7674 Valid Acc: 0.7799\n",
"[./data/notebook2/qccnn_best.pt] Epoch 39/50 Train Acc: 0.7649 Valid Acc: 0.7816\n",
"[./data/notebook2/qccnn_best.pt] Epoch 40/50 Train Acc: 0.7661 Valid Acc: 0.7823\n",
"[./data/notebook2/qccnn_best.pt] Epoch 41/50 Train Acc: 0.7668 Valid Acc: 0.7818\n",
"[./data/notebook2/qccnn_best.pt] Epoch 42/50 Train Acc: 0.7662 Valid Acc: 0.7818\n",
"[./data/notebook2/qccnn_best.pt] Epoch 43/50 Train Acc: 0.7668 Valid Acc: 0.7824\n",
"[./data/notebook2/qccnn_best.pt] Epoch 44/50 Train Acc: 0.7678 Valid Acc: 0.7825\n",
"[./data/notebook2/qccnn_best.pt] Epoch 45/50 Train Acc: 0.7677 Valid Acc: 0.7826\n",
"[./data/notebook2/qccnn_best.pt] Epoch 46/50 Train Acc: 0.7666 Valid Acc: 0.7827\n",
"[./data/notebook2/qccnn_best.pt] Epoch 47/50 Train Acc: 0.7689 Valid Acc: 0.7828\n",
"[./data/notebook2/qccnn_best.pt] Epoch 48/50 Train Acc: 0.7675 Valid Acc: 0.7827\n",
"[./data/notebook2/qccnn_best.pt] Epoch 49/50 Train Acc: 0.7678 Valid Acc: 0.7828\n",
"[./data/notebook2/qccnn_best.pt] Epoch 50/50 Train Acc: 0.7677 Valid Acc: 0.7826\n",
"\n",
"=== Training vgg ===\n",
"[./data/notebook2/vgg_best.pt] Epoch 1/50 Train Acc: 0.2536 Valid Acc: 0.4195\n",
"[./data/notebook2/vgg_best.pt] Epoch 2/50 Train Acc: 0.4692 Valid Acc: 0.5073\n",
"[./data/notebook2/vgg_best.pt] Epoch 3/50 Train Acc: 0.5270 Valid Acc: 0.5266\n",
"[./data/notebook2/vgg_best.pt] Epoch 4/50 Train Acc: 0.5354 Valid Acc: 0.5351\n",
"[./data/notebook2/vgg_best.pt] Epoch 5/50 Train Acc: 0.5935 Valid Acc: 0.6175\n",
"[./data/notebook2/vgg_best.pt] Epoch 6/50 Train Acc: 0.6366 Valid Acc: 0.6626\n",
"[./data/notebook2/vgg_best.pt] Epoch 7/50 Train Acc: 0.6866 Valid Acc: 0.7303\n",
"[./data/notebook2/vgg_best.pt] Epoch 8/50 Train Acc: 0.7415 Valid Acc: 0.7512\n",
"[./data/notebook2/vgg_best.pt] Epoch 9/50 Train Acc: 0.7542 Valid Acc: 0.7582\n",
"[./data/notebook2/vgg_best.pt] Epoch 10/50 Train Acc: 0.7608 Valid Acc: 0.7646\n",
"[./data/notebook2/vgg_best.pt] Epoch 11/50 Train Acc: 0.7677 Valid Acc: 0.7719\n",
"[./data/notebook2/vgg_best.pt] Epoch 12/50 Train Acc: 0.7705 Valid Acc: 0.7731\n",
"[./data/notebook2/vgg_best.pt] Epoch 13/50 Train Acc: 0.7716 Valid Acc: 0.7736\n",
"[./data/notebook2/vgg_best.pt] Epoch 14/50 Train Acc: 0.7748 Valid Acc: 0.7749\n",
"[./data/notebook2/vgg_best.pt] Epoch 15/50 Train Acc: 0.7763 Valid Acc: 0.7751\n",
"[./data/notebook2/vgg_best.pt] Epoch 16/50 Train Acc: 0.7790 Valid Acc: 0.7768\n",
"[./data/notebook2/vgg_best.pt] Epoch 17/50 Train Acc: 0.7801 Valid Acc: 0.7802\n",
"[./data/notebook2/vgg_best.pt] Epoch 18/50 Train Acc: 0.7818 Valid Acc: 0.7828\n",
"[./data/notebook2/vgg_best.pt] Epoch 19/50 Train Acc: 0.7827 Valid Acc: 0.7816\n",
"[./data/notebook2/vgg_best.pt] Epoch 20/50 Train Acc: 0.7827 Valid Acc: 0.7840\n",
"[./data/notebook2/vgg_best.pt] Epoch 21/50 Train Acc: 0.7849 Valid Acc: 0.7858\n",
"[./data/notebook2/vgg_best.pt] Epoch 22/50 Train Acc: 0.7877 Valid Acc: 0.7849\n",
"[./data/notebook2/vgg_best.pt] Epoch 23/50 Train Acc: 0.7880 Valid Acc: 0.7849\n",
"[./data/notebook2/vgg_best.pt] Epoch 24/50 Train Acc: 0.7891 Valid Acc: 0.7860\n",
"[./data/notebook2/vgg_best.pt] Epoch 25/50 Train Acc: 0.7903 Valid Acc: 0.7884\n",
"[./data/notebook2/vgg_best.pt] Epoch 26/50 Train Acc: 0.7910 Valid Acc: 0.7900\n",
"[./data/notebook2/vgg_best.pt] Epoch 27/50 Train Acc: 0.7919 Valid Acc: 0.7886\n",
"[./data/notebook2/vgg_best.pt] Epoch 28/50 Train Acc: 0.7937 Valid Acc: 0.7906\n",
"[./data/notebook2/vgg_best.pt] Epoch 29/50 Train Acc: 0.7934 Valid Acc: 0.7876\n",
"[./data/notebook2/vgg_best.pt] Epoch 30/50 Train Acc: 0.7946 Valid Acc: 0.7902\n",
"[./data/notebook2/vgg_best.pt] Epoch 31/50 Train Acc: 0.7959 Valid Acc: 0.7933\n",
"[./data/notebook2/vgg_best.pt] Epoch 32/50 Train Acc: 0.7956 Valid Acc: 0.7933\n",
"[./data/notebook2/vgg_best.pt] Epoch 33/50 Train Acc: 0.7971 Valid Acc: 0.7920\n",
"[./data/notebook2/vgg_best.pt] Epoch 34/50 Train Acc: 0.7974 Valid Acc: 0.7932\n",
"[./data/notebook2/vgg_best.pt] Epoch 35/50 Train Acc: 0.7980 Valid Acc: 0.7948\n",
"[./data/notebook2/vgg_best.pt] Epoch 36/50 Train Acc: 0.7985 Valid Acc: 0.7950\n",
"[./data/notebook2/vgg_best.pt] Epoch 37/50 Train Acc: 0.7990 Valid Acc: 0.7959\n",
"[./data/notebook2/vgg_best.pt] Epoch 38/50 Train Acc: 0.7992 Valid Acc: 0.7949\n",
"[./data/notebook2/vgg_best.pt] Epoch 39/50 Train Acc: 0.8001 Valid Acc: 0.7960\n",
"[./data/notebook2/vgg_best.pt] Epoch 40/50 Train Acc: 0.8006 Valid Acc: 0.7957\n",
"[./data/notebook2/vgg_best.pt] Epoch 41/50 Train Acc: 0.8007 Valid Acc: 0.7963\n",
"[./data/notebook2/vgg_best.pt] Epoch 42/50 Train Acc: 0.8015 Valid Acc: 0.7959\n",
"[./data/notebook2/vgg_best.pt] Epoch 43/50 Train Acc: 0.8014 Valid Acc: 0.7962\n",
"[./data/notebook2/vgg_best.pt] Epoch 44/50 Train Acc: 0.8016 Valid Acc: 0.7965\n",
"[./data/notebook2/vgg_best.pt] Epoch 45/50 Train Acc: 0.8018 Valid Acc: 0.7958\n",
"[./data/notebook2/vgg_best.pt] Epoch 46/50 Train Acc: 0.8021 Valid Acc: 0.7965\n",
"[./data/notebook2/vgg_best.pt] Epoch 47/50 Train Acc: 0.8025 Valid Acc: 0.7966\n",
"[./data/notebook2/vgg_best.pt] Epoch 48/50 Train Acc: 0.8026 Valid Acc: 0.7962\n",
"[./data/notebook2/vgg_best.pt] Epoch 49/50 Train Acc: 0.8025 Valid Acc: 0.7970\n",
"[./data/notebook2/vgg_best.pt] Epoch 50/50 Train Acc: 0.8025 Valid Acc: 0.7970\n",
"\n",
"--- Testing random_qccnn ---\n",
"Test Accuracy: 0.8249\n",
"\n",
"--- Testing qccnn ---\n",
"Test Accuracy: 0.7720\n",
"\n",
"--- Testing vgg ---\n",
"Test Accuracy: 0.7899\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 3 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAHqCAYAAADVi/1VAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAzRBJREFUeJzs3XlYlOX6B/DvLMwM67CvsokiICqKK4plGqZmWplkLqnY8rOTmWZlVifNk9npKGnpOZZGntwy2065RLnvirjvomyy7+sMzLy/PwZGJxZBgYHh+7muuU6888w798upeZj7vZ/7EQmCIICIiIiIiIiIiKgFiY0dABERERERERERtT9MShERERERERERUYtjUoqIiIiIiIiIiFock1JERERERERERNTimJQiIiIiIiIiIqIWx6QUERERERERERG1OCaliIiIiIiIiIioxTEpRURERERERERELY5JKSIiIiIiIiIianFMShE1kI+PD6ZOnWrsMIiIiIiIiIhMApNSRERERERERETU4piUojartLTU2CEQERERERER0X1iUorahA8++AAikQinTp3CuHHjYGdnBz8/P5w8eRLPPvssfHx8YG5uDh8fH0yYMAGJiYkGr4+JiYFIJMKePXvwf//3f3B0dISDgwOeeuop3L5922BsRUUF3nzzTbi6usLCwgKDBg3C8ePHa43r/PnzGDNmDOzs7KBQKBASEoJvvvnGYMzevXshEomwceNGvPXWW3Bzc4OVlRVGjx6NjIwMFBUV4cUXX4SjoyMcHR0xbdo0FBcXN/p3FBMTgy5dukAulyMwMBDr16/H1KlT4ePjYzBOpVJh0aJFCAwMhEKhgIODA4YMGYLDhw/rx2i1WqxcuRIhISEwNzeHra0t+vfvj19++UU/xsfHB48//jh27tyJXr16wdzcHAEBAVi3bt19/+6JiEzRb7/9hpCQEMjlcvj6+uLTTz/Vz2vVGvK5CwAbN27EgAEDYGVlBSsrK4SEhGDt2rX65x9++GEEBwfjxIkTCA8Ph4WFBTp27IiPP/4YWq1WP656btq0aRMWLFgAd3d32NjYYNiwYbhy5Urz/1KIiKjBfvrpJ4hEIvz55581nlu9ejVEIhHOnj0LAPjyyy/h7+8PuVyOoKAgbNy4sdbvBCkpKRg3bhysra1ha2uLiRMn4sSJExCJRIiJiWmBqyLSkRo7AKLGeOqpp/Dss8/i5ZdfRklJCW7duoUuXbrg2Wefhb29PdLS0rB69Wr06dMHFy9ehKOjo8HrZ8yYgVGjRmHjxo1ITk7GvHnzMGnSJOzevVs/5oUXXsD69evxxhtv4NFHH8X58+fx1FNPoaioyOBcV65cQVhYGJydnbFixQo4ODjg22+/xdSpU5GRkYE333zTYPw777yDIUOGICYmBrdu3cIbb7yBCRMmQCqVokePHti0aRPi4+PxzjvvwNraGitWrGjw7yUmJgbTpk3DmDFj8K9//QsFBQX44IMPoFKpIBbfyT1XVlZixIgROHDgAGbPno1HHnkElZWVOHr0KJKSkhAWFgYAmDp1Kr799ltERUVh0aJFkMlkOHXqFG7dumXwvmfOnMHcuXPx9ttvw8XFBV999RWioqLQqVMnDB48uNG/eyIiU/Pnn39izJgxGDBgADZv3gyNRoNPPvkEGRkZBuMa8rn7/vvv48MPP8RTTz2FuXPnQqlU4vz58zVuxKSnp2PixImYO3cu/v73v+PHH3/E/Pnz4e7ujilTphiMfeeddzBw4EB89dVXKCwsxFtvvYXRo0fj0qVLkEgkzfZ7ISKihnv88cfh7OyMr7/+GkOHDjV4LiYmBr169UL37t2xZs0avPTSS3j66aexfPlyFBQUYOHChVCpVAavKSkpwZAhQ5Cbm4ulS5eiU6dO2LlzJyIjI1vysoh0BKI24O9//7sAQHj//ffrHVdZWSkUFxcLlpaWwmeffaY//vXXXwsAhJkzZxqM/+STTw
QAQlpamiAIgnDp0iUBgPD6668bjNuwYYMAQHj++ef1x5599llBLpcLSUlJBmNHjBghWFhYCPn5+YIgCMKePXsEAMLo0aMNxs2ePVsAIMyaNcvg+NixYwV7e/t6r/NuGo1GcHd3F3r16iVotVr98Vu3bglmZmaCt7e3/tj69esFAMKXX35Z5/n2798vABAWLFhQ7/t6e3sLCoVCSExM1B8rKysT7O3thZdeekl/rKG/eyIiU9SvXz/B3d1dKCsr0x8rLCwU7O3theo/wxryuZuQkCBIJBJh4sSJ9b7fQw89JAAQjh07ZnA8KChIGD58uP7n6rlp5MiRBuO+++47AYBw5MiRBl8jERE1vzlz5gjm5ub67xiCIAgXL14UAAgrV64UNBqN4OrqKvTr18/gdYmJiTW+E3zxxRcCAGHHjh0GY1966SUBgPD1118356UQGeDyPWpTnn76aYOfi4uL8dZbb6FTp06QSqWQSqWwsrJCSUkJLl26VOP1TzzxhMHP3bt3BwD9XeY9e/YAACZOnGgwbvz48ZBKDQsLd+/ejaFDh8LT09Pg+NSpU1FaWoojR44YHH/88ccNfg4MDAQAjBo1qsbx3NzcBi/hu3LlCm7fvo3nnnvOYCmIt7e3vvKp2o4dO6BQKDB9+vQ6z7djxw4AwCuvvHLP9w4JCYGXl5f+Z4VCAX9//xp37YF7/+6JiExNSUkJTpw4gaeeegoKhUJ/3NraGqNHj9b/3JDP3djYWGg0mgZ9Nru6uqJv374Gx7p3787PZiKiNmz69OkoKyvDli1b9Me+/vpryOVyPPfcc7hy5QrS09Mxfvx4g9d5eXlh4MCBBsf27dsHa2trPPbYYwbHJ0yY0HwXQFQHJqWoTXFzczP4+bnnnsPnn3+OGTNmYNeuXTh+/DhOnDgBJycnlJWV1Xi9g4ODwc9yuRwA9GNzcnIA6P6gv5tUKq3x2pycnBrxAIC7u7vBuarZ29sb/CyTyeo9Xl5eXuPctakr5tqOZWVlwd3d3WBJ319lZWVBIpHUer6/+uvvBND9Tu/nd09EZGry8vKg1Wrv+fnckM/drKwsAECHDh3u+b78bCYiMj1du3ZFnz598PXXXwMANBoNvv32W4wZMwb29vb67wQuLi41XvvXYzk5OQ0aR9QSmJSiNuXuSqCCggL8+uuvePPNN/H2229j6NCh6NOnD7p164bc3Nz7On/1H+fp6ekGxysrK2skmRwcHJCWllbjHNXNu//az6q51BVzbcecnJxw+/Ztg2a3f+Xk5ASNRlPr+YiIqOHs7OwgEonu+fnckM9dJycnALrGtERE1D5NmzYNR48exaVLl7Bz506kpaVh2rRpAO58J/hrz0Kg5ncCBweHBo0jaglMSlGbJRKJIAiC/q5uta+++goajea+zvnwww8DADZs2GBw/LvvvkNlZaXBsaFDh2L37t01dpBbv349LCws0L9///uKobG6dOkCNzc3bNq0CYIg6I8nJiYa7KgHACNGjEB5eXm9O2qMGDECgG4nDyIiun+Wlpbo27cvfvjhB4Pq16KiIvzvf//T/9yQz92IiAhIJBJ+NhMRtWMTJkyAQqFATEwMYmJi4OHhgYiICAC67wSurq747rvvDF6TlJRU4zvBQw89hKKiIv3y8WqbN29u3gsgqgV336M2y8bGBoMHD8Y///lPODo6wsfHB/v27cPatWtha2t7X+cMDAzEpEmTEB0dDTMzMwwbNgznz5/Hp59+ChsbG4Oxf//73/Hrr79iyJAheP/992Fvb48NGzbgt99+wyeffAKlUtkEV3lvYrEYH374IWbMmIEnn3wSL7zwAvLz8/HBBx/UWAoyYcIEfP3113j55Zdx5coVDBkyBFqtFseOHUNgYCCeffZZhIeHY/LkyVi8eDEyMjLw+OOPQy6XIz4+HhYWFnj11Vdb5LqIiEzBhx9+iMceewyPPvoo5s6dC41Gg6VLl8LS0lJf1duQz10fHx+88847+PDDD1FWVo
YJEyZAqVTi4sWLyM7OxsKFC418pURE1NxsbW3x5JNPIiYmBvn5+XjjjTf0bTnEYjEWLlyIl156CePGjcP06dORn5+
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Parameter Counts:\n",
"random_qccnn: 232581\n",
"qccnn: 26042\n",
"vgg: 49870\n"
]
}
],
"execution_count": 16
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"下面是新版本脚本相对于原始代码的主要改进点及对应说明,已结合关键代码片段进行标注:\n",
"\n",
"1. **数据增强(Data Augmentation)**\n",
"\n",
" * **原始**:训练和验证都只做了 `Resize` + `ToTensor`,数据量固定;\n",
" * **改进**:对训练集添加了 `RandomRotation`、`RandomHorizontalFlip`、`RandomVerticalFlip` 等操作,大幅增加了样本多样性,有助于提升模型泛化能力。\n",
"\n",
" ```python\n",
" train_transform = transforms.Compose([\n",
" transforms.Resize((18, 18)),\n",
" transforms.RandomRotation(15), # 随机旋转\n",
" transforms.RandomHorizontalFlip(), # 随机水平翻转\n",
" transforms.RandomVerticalFlip(0.3), # 随机垂直翻转\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))\n",
" ])\n",
" ```\n",
"\n",
"2. **优化器与学习率调度**\n",
"\n",
" * **原始**:`SGD(lr=0.01, weight_decay=0.001)` 或 `Adam(lr=1e-5)`,无学习率变化;\n",
" * **改进**:统一使用 `AdamW`(带权重衰减的 Adam,更稳定),并添加 `CosineAnnealingLR`,能在训练中动态调整学习率,促进更快收敛、更高精度。\n",
"\n",
" ```python\n",
" optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n",
" ```\n",
"\n",
"3. **批大小与训练轮次(Batch Size & Epochs)**\n",
"\n",
" * **原始**:`batch_size=64`,`num_epochs=300`;\n",
" * **改进**:增大到 `batch_size=128`,减少到 `num_epochs=50`。\n",
"\n",
" * **批量更大**:更好利用 GPU 并行能力;\n",
" * **轮次更少**:缩短训练时间,同时在学习率调度下仍能达到高准确度。\n",
"\n",
"4. **梯度裁剪(Gradient Clipping)**\n",
"\n",
" * 增加了 `torch.nn.utils.clip_grad_norm_(…)`,避免量子网络中可能出现的梯度爆炸,保障训练稳定性。\n",
"\n",
" ```python\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
" ```\n",
"\n",
"5. **统一训练框架与多模型支持**\n",
"\n",
" * 将原来三段几乎重复的训练逻辑,抽象成 `train_model(...)` 函数,并通过一个字典 `models` 循环一次性训练:\n",
"\n",
" * `random_qccnn`、`qccnn`、`vgg` 三种架构统一流程;\n",
" * 每种模型均保存最优权重(`save_path`)与训练曲线(`metrics.csv`)。\n",
"\n",
" ```python\n",
" models = {\n",
" 'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),\n",
" 'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),\n",
" 'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')\n",
" }\n",
" for name, (model, lr, save_path) in models.items():\n",
" trained, metrics = train_model(\n",
" model, criterion, optimizer, scheduler,\n",
" train_loader, valid_loader,\n",
" num_epochs=50, device=device, save_path=save_path\n",
" )\n",
" pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)\n",
" ```\n",
"\n",
"6. **最佳模型保存和测试分离**\n",
"\n",
" * `train_model` 内部自动监控验证集准确率并保存最佳参数;\n",
" * 训练结束后,再统一加载最佳权重进行测试,确保测试结果与最优状态对应。\n",
"\n",
"7. **可视化对比**\n",
"\n",
" * 最后一个代码块中,通过 `matplotlib` 并排绘制三种模型的验证准确率曲线,直观比较不同模型的收敛速度和最终性能。\n",
"\n",
" ```python\n",
" plt.figure(figsize=(12,5))\n",
" for i,(name,metrics) in enumerate(all_metrics.items(),1):\n",
" plt.subplot(1,3,i)\n",
" plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n",
" plt.title(name); plt.legend()\n",
" plt.tight_layout(); plt.show()\n",
" ```\n",
"\n",
"8. **参数量统计**\n",
"\n",
" * 在脚本末尾增加 `count_parameters` 函数,打印三种模型的可训练参数量,帮助评估模型复杂度与性能的权衡。\n",
"\n",
" ```python\n",
" def count_parameters(m):\n",
" return sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
" for name,(model,_,_) in models.items():\n",
" print(f\"{name}: {count_parameters(model)}\")\n",
" ```\n",
"\n",
"---\n",
"\n",
"**总体效果**\n",
"\n",
"* **训练速度**:批量更大、轮次更少、学习率动态调度,整体训练时间显著缩短。\n",
"* **模型准确率**:数据增强 + 优化器 + 调度器 + 梯度裁剪等多项改进,显著提高了各模型在验证集和测试集上的准确率。\n",
"* **可维护性**:统一框架、函数抽象、循环训练及结果保存,大大简化了代码结构,便于后续扩展和调试。\n"
],
"id": "6a5e9602f481107e"
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}