697 lines
103 KiB
Plaintext
697 lines
103 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.323528Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.321407Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": "# Modify.py\n",
|
|||
|
"id": "98da35e8f6af6b7a",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 9
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.355899Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.352650Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import random\n",
|
|||
|
"\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import deepquantum as dq\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.optim as optim\n",
|
|||
|
"import torchvision.transforms as transforms\n",
|
|||
|
"from torchvision.datasets import FashionMNIST\n",
|
|||
|
"from tqdm import tqdm\n",
|
|||
|
"from torch.utils.data import DataLoader\n",
|
|||
|
"from multiprocessing import freeze_support\n"
|
|||
|
],
|
|||
|
"id": "fba02718b5ce470f",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 10
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.376643Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.373463Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"def seed_torch(seed=1024):\n",
|
|||
|
" random.seed(seed)\n",
|
|||
|
" os.environ['PYTHONHASHSEED'] = str(seed)\n",
|
|||
|
" np.random.seed(seed)\n",
|
|||
|
" torch.manual_seed(seed)\n",
|
|||
|
" torch.cuda.manual_seed(seed)\n",
|
|||
|
" torch.cuda.manual_seed_all(seed)\n",
|
|||
|
" torch.backends.cudnn.benchmark = False\n",
|
|||
|
" torch.backends.cudnn.deterministic = True\n"
|
|||
|
],
|
|||
|
"id": "e21c1ebc100b5079",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 11
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.406416Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.403413Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"def calculate_score(y_true, y_preds):\n",
|
|||
|
" preds_prob = torch.softmax(y_preds, dim=1)\n",
|
|||
|
" preds_class = torch.argmax(preds_prob, dim=1)\n",
|
|||
|
" correct = (preds_class == y_true).float()\n",
|
|||
|
" return (correct.sum() / len(correct)).cpu().numpy()\n"
|
|||
|
],
|
|||
|
"id": "f70cc647d264747",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 12
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.439020Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.433485Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):\n",
|
|||
|
" model.to(device)\n",
|
|||
|
" best_acc = 0.0\n",
|
|||
|
" metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}\n",
|
|||
|
"\n",
|
|||
|
" for epoch in range(1, num_epochs + 1):\n",
|
|||
|
" # --- 训练阶段 ---\n",
|
|||
|
" model.train()\n",
|
|||
|
" running_loss, running_acc = 0.0, 0.0\n",
|
|||
|
" for imgs, labels in train_loader:\n",
|
|||
|
" imgs, labels = imgs.to(device), labels.to(device)\n",
|
|||
|
" optimizer.zero_grad()\n",
|
|||
|
" outputs = model(imgs)\n",
|
|||
|
" loss = criterion(outputs, labels)\n",
|
|||
|
" loss.backward()\n",
|
|||
|
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
|
|||
|
" optimizer.step()\n",
|
|||
|
" running_loss += loss.item()\n",
|
|||
|
" running_acc += calculate_score(labels, outputs)\n",
|
|||
|
"\n",
|
|||
|
" train_loss = running_loss / len(train_loader)\n",
|
|||
|
" train_acc = running_acc / len(train_loader)\n",
|
|||
|
" scheduler.step()\n",
|
|||
|
"\n",
|
|||
|
" # --- 验证阶段 ---\n",
|
|||
|
" model.eval()\n",
|
|||
|
" val_loss, val_acc = 0.0, 0.0\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" for imgs, labels in valid_loader:\n",
|
|||
|
" imgs, labels = imgs.to(device), labels.to(device)\n",
|
|||
|
" outputs = model(imgs)\n",
|
|||
|
" loss = criterion(outputs, labels)\n",
|
|||
|
" val_loss += loss.item()\n",
|
|||
|
" val_acc += calculate_score(labels, outputs)\n",
|
|||
|
"\n",
|
|||
|
" valid_loss = val_loss / len(valid_loader)\n",
|
|||
|
" valid_acc = val_acc / len(valid_loader)\n",
|
|||
|
"\n",
|
|||
|
" metrics['epoch'].append(epoch)\n",
|
|||
|
" metrics['train_loss'].append(train_loss)\n",
|
|||
|
" metrics['valid_loss'].append(valid_loss)\n",
|
|||
|
" metrics['train_acc'].append(train_acc)\n",
|
|||
|
" metrics['valid_acc'].append(valid_acc)\n",
|
|||
|
"\n",
|
|||
|
" tqdm.write(f\"[{save_path}] Epoch {epoch}/{num_epochs} \"\n",
|
|||
|
" f\"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}\")\n",
|
|||
|
"\n",
|
|||
|
" if valid_acc > best_acc:\n",
|
|||
|
" best_acc = valid_acc\n",
|
|||
|
" torch.save(model.state_dict(), save_path)\n",
|
|||
|
"\n",
|
|||
|
" return model, metrics\n"
|
|||
|
],
|
|||
|
"id": "7cfe89d63ab0e68e",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 13
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.467170Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.463678Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"def test_model(model, test_loader, device):\n",
|
|||
|
" model.to(device).eval()\n",
|
|||
|
" acc = 0.0\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" for imgs, labels in test_loader:\n",
|
|||
|
" imgs, labels = imgs.to(device), labels.to(device)\n",
|
|||
|
" outputs = model(imgs)\n",
|
|||
|
" acc += calculate_score(labels, outputs)\n",
|
|||
|
" acc /= len(test_loader)\n",
|
|||
|
" print(f\"Test Accuracy: {acc:.4f}\")\n",
|
|||
|
" return acc\n"
|
|||
|
],
|
|||
|
"id": "235ac16bd786e65f",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 14
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T17:37:28.506986Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.496003Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"singlegate_list = ['rx','ry','rz','s','t','p','u3']\n",
|
|||
|
"doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']\n",
|
|||
|
"\n",
|
|||
|
"class RandomQuantumConvolutionalLayer(nn.Module):\n",
|
|||
|
" def __init__(self, nqubit, num_circuits, seed=1024):\n",
|
|||
|
" super().__init__()\n",
|
|||
|
" random.seed(seed)\n",
|
|||
|
" self.nqubit = nqubit\n",
|
|||
|
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
|
|||
|
" def circuit(self, nqubit):\n",
|
|||
|
" cir = dq.QubitCircuit(nqubit)\n",
|
|||
|
" cir.rxlayer(encode=True); cir.barrier()\n",
|
|||
|
" for _ in range(3):\n",
|
|||
|
" for i in range(nqubit):\n",
|
|||
|
" getattr(cir, random.choice(singlegate_list))(i)\n",
|
|||
|
" c,t = random.sample(range(nqubit),2)\n",
|
|||
|
" gate = random.choice(doublegate_list)\n",
|
|||
|
" if gate[0] in ['r','s']:\n",
|
|||
|
" getattr(cir, gate)([c,t])\n",
|
|||
|
" else:\n",
|
|||
|
" getattr(cir, gate)(c,t)\n",
|
|||
|
" cir.barrier()\n",
|
|||
|
" cir.observable(0)\n",
|
|||
|
" return cir\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" k,s = 2,2\n",
|
|||
|
" x_unf = x.unfold(2,k,s).unfold(3,k,s)\n",
|
|||
|
" w = (x.shape[-1]-k)//s + 1\n",
|
|||
|
" x_r = x_unf.reshape(-1, self.nqubit)\n",
|
|||
|
" exps = []\n",
|
|||
|
" for cir in self.cirs:\n",
|
|||
|
" cir(x_r)\n",
|
|||
|
" exps.append(cir.expectation())\n",
|
|||
|
" exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)\n",
|
|||
|
" return exps\n",
|
|||
|
"\n",
|
|||
|
"class RandomQCCNN(nn.Module):\n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" super().__init__()\n",
|
|||
|
" self.conv = nn.Sequential(\n",
|
|||
|
" RandomQuantumConvolutionalLayer(4,3,seed=1024),\n",
|
|||
|
" nn.ReLU(), nn.MaxPool2d(2,1),\n",
|
|||
|
" nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)\n",
|
|||
|
" )\n",
|
|||
|
" self.fc = nn.Sequential(\n",
|
|||
|
" nn.Linear(6*6*6,1024), nn.Dropout(0.4),\n",
|
|||
|
" nn.Linear(1024,10)\n",
|
|||
|
" )\n",
|
|||
|
" def forward(self,x):\n",
|
|||
|
" x = self.conv(x)\n",
|
|||
|
" x = x.view(x.size(0),-1)\n",
|
|||
|
" return self.fc(x)\n",
|
|||
|
"\n",
|
|||
|
"class ParameterizedQuantumConvolutionalLayer(nn.Module):\n",
|
|||
|
" def __init__(self,nqubit,num_circuits):\n",
|
|||
|
" super().__init__()\n",
|
|||
|
" self.nqubit = nqubit\n",
|
|||
|
" self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n",
|
|||
|
" def circuit(self,nqubit):\n",
|
|||
|
" cir = dq.QubitCircuit(nqubit)\n",
|
|||
|
" cir.rxlayer(encode=True); cir.barrier()\n",
|
|||
|
" for _ in range(4):\n",
|
|||
|
" cir.rylayer(); cir.cnot_ring(); cir.barrier()\n",
|
|||
|
" cir.observable(0)\n",
|
|||
|
" return cir\n",
|
|||
|
" def forward(self,x):\n",
|
|||
|
" k,s = 2,2\n",
|
|||
|
" x_unf = x.unfold(2,k,s).unfold(3,k,s)\n",
|
|||
|
" w = (x.shape[-1]-k)//s +1\n",
|
|||
|
" x_r = x_unf.reshape(-1,self.nqubit)\n",
|
|||
|
" exps = []\n",
|
|||
|
" for cir in self.cirs:\n",
|
|||
|
" cir(x_r); exps.append(cir.expectation())\n",
|
|||
|
" exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)\n",
|
|||
|
" return exps\n",
|
|||
|
"\n",
|
|||
|
"class QCCNN(nn.Module):\n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" super().__init__()\n",
|
|||
|
" self.conv = nn.Sequential(\n",
|
|||
|
" ParameterizedQuantumConvolutionalLayer(4,3),\n",
|
|||
|
" nn.ReLU(), nn.MaxPool2d(2,1)\n",
|
|||
|
" )\n",
|
|||
|
" self.fc = nn.Sequential(\n",
|
|||
|
" nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),\n",
|
|||
|
" nn.Linear(128,10)\n",
|
|||
|
" )\n",
|
|||
|
" def forward(self,x):\n",
|
|||
|
" x = self.conv(x); x = x.view(x.size(0),-1)\n",
|
|||
|
" return self.fc(x)\n",
|
|||
|
"\n",
|
|||
|
"def vgg_block(in_c,out_c,n_convs):\n",
|
|||
|
" layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]\n",
|
|||
|
" for _ in range(n_convs-1):\n",
|
|||
|
" layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]\n",
|
|||
|
" layers.append(nn.MaxPool2d(2,2))\n",
|
|||
|
" return nn.Sequential(*layers)\n",
|
|||
|
"\n",
|
|||
|
"VGG = nn.Sequential(\n",
|
|||
|
" vgg_block(1,10,3),\n",
|
|||
|
" vgg_block(10,16,3),\n",
|
|||
|
" nn.Flatten(),\n",
|
|||
|
" nn.Linear(16*4*4,120), nn.Sigmoid(),\n",
|
|||
|
" nn.Linear(120,84), nn.Sigmoid(),\n",
|
|||
|
" nn.Linear(84,10), nn.Softmax(dim=-1)\n",
|
|||
|
")\n"
|
|||
|
],
|
|||
|
"id": "c996822b3d5f8305",
|
|||
|
"outputs": [],
|
|||
|
"execution_count": 15
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {
|
|||
|
"ExecuteTime": {
|
|||
|
"end_time": "2025-06-24T18:51:44.622880Z",
|
|||
|
"start_time": "2025-06-24T17:37:28.529442Z"
|
|||
|
}
|
|||
|
},
|
|||
|
"cell_type": "code",
|
|||
|
"source": [
|
|||
|
"if __name__ == '__main__':\n",
|
|||
|
" freeze_support()\n",
|
|||
|
"\n",
|
|||
|
" # 数据增广与加载\n",
|
|||
|
" train_transform = transforms.Compose([\n",
|
|||
|
" transforms.Resize((18, 18)),\n",
|
|||
|
" transforms.RandomRotation(15),\n",
|
|||
|
" transforms.RandomHorizontalFlip(),\n",
|
|||
|
" transforms.RandomVerticalFlip(0.3),\n",
|
|||
|
" transforms.ToTensor(),\n",
|
|||
|
" transforms.Normalize((0.5,), (0.5,))\n",
|
|||
|
" ])\n",
|
|||
|
" eval_transform = transforms.Compose([\n",
|
|||
|
" transforms.Resize((18, 18)),\n",
|
|||
|
" transforms.ToTensor(),\n",
|
|||
|
" transforms.Normalize((0.5,), (0.5,))\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
" full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)\n",
|
|||
|
" test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)\n",
|
|||
|
" train_size = int(0.8 * len(full_train))\n",
|
|||
|
" valid_size = len(full_train) - train_size\n",
|
|||
|
" train_ds, valid_ds = torch.utils.data.random_split(full_train, [train_size, valid_size])\n",
|
|||
|
" valid_ds.dataset.transform = eval_transform\n",
|
|||
|
"\n",
|
|||
|
" batch_size = 128\n",
|
|||
|
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)\n",
|
|||
|
" valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)\n",
|
|||
|
" test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)\n",
|
|||
|
"\n",
|
|||
|
" # 三种模型配置\n",
|
|||
|
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|||
|
" models = {\n",
|
|||
|
" 'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),\n",
|
|||
|
" 'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),\n",
|
|||
|
" 'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" all_metrics = {}\n",
|
|||
|
" for name, (model, lr, save_path) in models.items():\n",
|
|||
|
" seed_torch(1024)\n",
|
|||
|
" model = model.to(device)\n",
|
|||
|
" criterion = nn.CrossEntropyLoss()\n",
|
|||
|
" optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)\n",
|
|||
|
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n",
|
|||
|
"\n",
|
|||
|
" print(f\"\\n=== Training {name} ===\")\n",
|
|||
|
" _, metrics = train_model(\n",
|
|||
|
" model, criterion, optimizer, scheduler,\n",
|
|||
|
" train_loader, valid_loader,\n",
|
|||
|
" num_epochs=50, device=device, save_path=save_path\n",
|
|||
|
" )\n",
|
|||
|
" all_metrics[name] = metrics\n",
|
|||
|
" pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)\n",
|
|||
|
"\n",
|
|||
|
" # 测试与可视化\n",
|
|||
|
" plt.figure(figsize=(12,5))\n",
|
|||
|
" for i,(name,metrics) in enumerate(all_metrics.items(),1):\n",
|
|||
|
" model, _, save_path = models[name]\n",
|
|||
|
" best_model = model.to(device)\n",
|
|||
|
" best_model.load_state_dict(torch.load(save_path))\n",
|
|||
|
" print(f\"\\n--- Testing {name} ---\")\n",
|
|||
|
" test_model(best_model, test_loader, device)\n",
|
|||
|
"\n",
|
|||
|
" plt.subplot(1,3,i)\n",
|
|||
|
" plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n",
|
|||
|
" plt.xlabel('Epoch'); plt.ylabel('Valid Acc')\n",
|
|||
|
" plt.title(name); plt.legend()\n",
|
|||
|
"\n",
|
|||
|
" plt.tight_layout(); plt.show()\n",
|
|||
|
"\n",
|
|||
|
" # 参数量统计\n",
|
|||
|
" def count_parameters(m):\n",
|
|||
|
" return sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
|
|||
|
"\n",
|
|||
|
" print(\"\\nParameter Counts:\")\n",
|
|||
|
" for name,(model,_,_) in models.items():\n",
|
|||
|
" print(f\"{name}: {count_parameters(model)}\")\n"
|
|||
|
],
|
|||
|
"id": "2d6b93bb78001086",
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"=== Training random_qccnn ===\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 1/50 Train Acc: 0.6377 Valid Acc: 0.7426\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 2/50 Train Acc: 0.7589 Valid Acc: 0.7799\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 3/50 Train Acc: 0.7830 Valid Acc: 0.7955\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 4/50 Train Acc: 0.7928 Valid Acc: 0.8010\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 5/50 Train Acc: 0.7997 Valid Acc: 0.8065\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 6/50 Train Acc: 0.8044 Valid Acc: 0.8140\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 7/50 Train Acc: 0.8097 Valid Acc: 0.8157\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 8/50 Train Acc: 0.8155 Valid Acc: 0.8163\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 9/50 Train Acc: 0.8162 Valid Acc: 0.8159\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 10/50 Train Acc: 0.8169 Valid Acc: 0.8160\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 11/50 Train Acc: 0.8210 Valid Acc: 0.8269\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 12/50 Train Acc: 0.8210 Valid Acc: 0.8266\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 13/50 Train Acc: 0.8241 Valid Acc: 0.8212\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 14/50 Train Acc: 0.8240 Valid Acc: 0.8264\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 15/50 Train Acc: 0.8245 Valid Acc: 0.8231\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 16/50 Train Acc: 0.8270 Valid Acc: 0.8291\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 17/50 Train Acc: 0.8274 Valid Acc: 0.8297\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 18/50 Train Acc: 0.8281 Valid Acc: 0.8338\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 19/50 Train Acc: 0.8295 Valid Acc: 0.8291\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 20/50 Train Acc: 0.8306 Valid Acc: 0.8304\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 21/50 Train Acc: 0.8320 Valid Acc: 0.8280\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 22/50 Train Acc: 0.8316 Valid Acc: 0.8293\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 23/50 Train Acc: 0.8315 Valid Acc: 0.8298\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 24/50 Train Acc: 0.8329 Valid Acc: 0.8282\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 25/50 Train Acc: 0.8321 Valid Acc: 0.8313\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 26/50 Train Acc: 0.8350 Valid Acc: 0.8311\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 27/50 Train Acc: 0.8343 Valid Acc: 0.8335\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 28/50 Train Acc: 0.8347 Valid Acc: 0.8320\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 29/50 Train Acc: 0.8354 Valid Acc: 0.8333\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 30/50 Train Acc: 0.8359 Valid Acc: 0.8314\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 31/50 Train Acc: 0.8380 Valid Acc: 0.8348\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 32/50 Train Acc: 0.8369 Valid Acc: 0.8330\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 33/50 Train Acc: 0.8375 Valid Acc: 0.8367\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 34/50 Train Acc: 0.8373 Valid Acc: 0.8315\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 35/50 Train Acc: 0.8381 Valid Acc: 0.8352\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 36/50 Train Acc: 0.8393 Valid Acc: 0.8374\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 37/50 Train Acc: 0.8384 Valid Acc: 0.8348\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 38/50 Train Acc: 0.8398 Valid Acc: 0.8355\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 39/50 Train Acc: 0.8402 Valid Acc: 0.8365\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 40/50 Train Acc: 0.8400 Valid Acc: 0.8346\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 41/50 Train Acc: 0.8411 Valid Acc: 0.8374\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 42/50 Train Acc: 0.8409 Valid Acc: 0.8373\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 43/50 Train Acc: 0.8415 Valid Acc: 0.8364\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 44/50 Train Acc: 0.8414 Valid Acc: 0.8359\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 45/50 Train Acc: 0.8409 Valid Acc: 0.8364\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 46/50 Train Acc: 0.8419 Valid Acc: 0.8371\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 47/50 Train Acc: 0.8422 Valid Acc: 0.8369\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 48/50 Train Acc: 0.8421 Valid Acc: 0.8360\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 49/50 Train Acc: 0.8417 Valid Acc: 0.8369\n",
|
|||
|
"[./data/notebook2/random_qccnn_best.pt] Epoch 50/50 Train Acc: 0.8408 Valid Acc: 0.8369\n",
|
|||
|
"\n",
|
|||
|
"=== Training qccnn ===\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 1/50 Train Acc: 0.3113 Valid Acc: 0.5076\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 2/50 Train Acc: 0.4621 Valid Acc: 0.5548\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 3/50 Train Acc: 0.5201 Valid Acc: 0.5892\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 4/50 Train Acc: 0.5642 Valid Acc: 0.6311\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 5/50 Train Acc: 0.6056 Valid Acc: 0.6599\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 6/50 Train Acc: 0.6317 Valid Acc: 0.6804\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 7/50 Train Acc: 0.6514 Valid Acc: 0.6929\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 8/50 Train Acc: 0.6678 Valid Acc: 0.7008\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 9/50 Train Acc: 0.6809 Valid Acc: 0.7096\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 10/50 Train Acc: 0.6877 Valid Acc: 0.7186\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 11/50 Train Acc: 0.6976 Valid Acc: 0.7247\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 12/50 Train Acc: 0.7031 Valid Acc: 0.7303\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 13/50 Train Acc: 0.7119 Valid Acc: 0.7317\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 14/50 Train Acc: 0.7164 Valid Acc: 0.7404\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 15/50 Train Acc: 0.7211 Valid Acc: 0.7438\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 16/50 Train Acc: 0.7249 Valid Acc: 0.7481\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 17/50 Train Acc: 0.7294 Valid Acc: 0.7500\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 18/50 Train Acc: 0.7345 Valid Acc: 0.7518\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 19/50 Train Acc: 0.7350 Valid Acc: 0.7550\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 20/50 Train Acc: 0.7391 Valid Acc: 0.7587\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 21/50 Train Acc: 0.7434 Valid Acc: 0.7608\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 22/50 Train Acc: 0.7443 Valid Acc: 0.7634\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 23/50 Train Acc: 0.7481 Valid Acc: 0.7654\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 24/50 Train Acc: 0.7498 Valid Acc: 0.7683\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 25/50 Train Acc: 0.7529 Valid Acc: 0.7696\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 26/50 Train Acc: 0.7547 Valid Acc: 0.7708\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 27/50 Train Acc: 0.7547 Valid Acc: 0.7723\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 28/50 Train Acc: 0.7580 Valid Acc: 0.7736\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 29/50 Train Acc: 0.7571 Valid Acc: 0.7749\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 30/50 Train Acc: 0.7602 Valid Acc: 0.7760\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 31/50 Train Acc: 0.7610 Valid Acc: 0.7767\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 32/50 Train Acc: 0.7618 Valid Acc: 0.7764\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 33/50 Train Acc: 0.7630 Valid Acc: 0.7784\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 34/50 Train Acc: 0.7632 Valid Acc: 0.7791\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 35/50 Train Acc: 0.7627 Valid Acc: 0.7786\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 36/50 Train Acc: 0.7653 Valid Acc: 0.7803\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 37/50 Train Acc: 0.7640 Valid Acc: 0.7811\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 38/50 Train Acc: 0.7674 Valid Acc: 0.7799\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 39/50 Train Acc: 0.7649 Valid Acc: 0.7816\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 40/50 Train Acc: 0.7661 Valid Acc: 0.7823\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 41/50 Train Acc: 0.7668 Valid Acc: 0.7818\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 42/50 Train Acc: 0.7662 Valid Acc: 0.7818\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 43/50 Train Acc: 0.7668 Valid Acc: 0.7824\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 44/50 Train Acc: 0.7678 Valid Acc: 0.7825\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 45/50 Train Acc: 0.7677 Valid Acc: 0.7826\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 46/50 Train Acc: 0.7666 Valid Acc: 0.7827\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 47/50 Train Acc: 0.7689 Valid Acc: 0.7828\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 48/50 Train Acc: 0.7675 Valid Acc: 0.7827\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 49/50 Train Acc: 0.7678 Valid Acc: 0.7828\n",
|
|||
|
"[./data/notebook2/qccnn_best.pt] Epoch 50/50 Train Acc: 0.7677 Valid Acc: 0.7826\n",
|
|||
|
"\n",
|
|||
|
"=== Training vgg ===\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 1/50 Train Acc: 0.2536 Valid Acc: 0.4195\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 2/50 Train Acc: 0.4692 Valid Acc: 0.5073\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 3/50 Train Acc: 0.5270 Valid Acc: 0.5266\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 4/50 Train Acc: 0.5354 Valid Acc: 0.5351\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 5/50 Train Acc: 0.5935 Valid Acc: 0.6175\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 6/50 Train Acc: 0.6366 Valid Acc: 0.6626\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 7/50 Train Acc: 0.6866 Valid Acc: 0.7303\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 8/50 Train Acc: 0.7415 Valid Acc: 0.7512\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 9/50 Train Acc: 0.7542 Valid Acc: 0.7582\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 10/50 Train Acc: 0.7608 Valid Acc: 0.7646\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 11/50 Train Acc: 0.7677 Valid Acc: 0.7719\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 12/50 Train Acc: 0.7705 Valid Acc: 0.7731\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 13/50 Train Acc: 0.7716 Valid Acc: 0.7736\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 14/50 Train Acc: 0.7748 Valid Acc: 0.7749\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 15/50 Train Acc: 0.7763 Valid Acc: 0.7751\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 16/50 Train Acc: 0.7790 Valid Acc: 0.7768\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 17/50 Train Acc: 0.7801 Valid Acc: 0.7802\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 18/50 Train Acc: 0.7818 Valid Acc: 0.7828\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 19/50 Train Acc: 0.7827 Valid Acc: 0.7816\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 20/50 Train Acc: 0.7827 Valid Acc: 0.7840\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 21/50 Train Acc: 0.7849 Valid Acc: 0.7858\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 22/50 Train Acc: 0.7877 Valid Acc: 0.7849\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 23/50 Train Acc: 0.7880 Valid Acc: 0.7849\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 24/50 Train Acc: 0.7891 Valid Acc: 0.7860\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 25/50 Train Acc: 0.7903 Valid Acc: 0.7884\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 26/50 Train Acc: 0.7910 Valid Acc: 0.7900\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 27/50 Train Acc: 0.7919 Valid Acc: 0.7886\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 28/50 Train Acc: 0.7937 Valid Acc: 0.7906\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 29/50 Train Acc: 0.7934 Valid Acc: 0.7876\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 30/50 Train Acc: 0.7946 Valid Acc: 0.7902\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 31/50 Train Acc: 0.7959 Valid Acc: 0.7933\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 32/50 Train Acc: 0.7956 Valid Acc: 0.7933\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 33/50 Train Acc: 0.7971 Valid Acc: 0.7920\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 34/50 Train Acc: 0.7974 Valid Acc: 0.7932\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 35/50 Train Acc: 0.7980 Valid Acc: 0.7948\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 36/50 Train Acc: 0.7985 Valid Acc: 0.7950\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 37/50 Train Acc: 0.7990 Valid Acc: 0.7959\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 38/50 Train Acc: 0.7992 Valid Acc: 0.7949\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 39/50 Train Acc: 0.8001 Valid Acc: 0.7960\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 40/50 Train Acc: 0.8006 Valid Acc: 0.7957\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 41/50 Train Acc: 0.8007 Valid Acc: 0.7963\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 42/50 Train Acc: 0.8015 Valid Acc: 0.7959\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 43/50 Train Acc: 0.8014 Valid Acc: 0.7962\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 44/50 Train Acc: 0.8016 Valid Acc: 0.7965\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 45/50 Train Acc: 0.8018 Valid Acc: 0.7958\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 46/50 Train Acc: 0.8021 Valid Acc: 0.7965\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 47/50 Train Acc: 0.8025 Valid Acc: 0.7966\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 48/50 Train Acc: 0.8026 Valid Acc: 0.7962\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 49/50 Train Acc: 0.8025 Valid Acc: 0.7970\n",
|
|||
|
"[./data/notebook2/vgg_best.pt] Epoch 50/50 Train Acc: 0.8025 Valid Acc: 0.7970\n",
|
|||
|
"\n",
|
|||
|
"--- Testing random_qccnn ---\n",
|
|||
|
"Test Accuracy: 0.8249\n",
|
|||
|
"\n",
|
|||
|
"--- Testing qccnn ---\n",
|
|||
|
"Test Accuracy: 0.7720\n",
|
|||
|
"\n",
|
|||
|
"--- Testing vgg ---\n",
|
|||
|
"Test Accuracy: 0.7899\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x500 with 3 Axes>"
|
|||
|
],
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAHqCAYAAADVi/1VAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAzRBJREFUeJzs3XlYlOX6B/DvLMwM67CvsokiICqKK4plGqZmWplkLqnY8rOTmWZlVifNk9npKGnpOZZGntwy2065RLnvirjvomyy7+sMzLy/PwZGJxZBgYHh+7muuU6888w798upeZj7vZ/7EQmCIICIiIiIiIiIiKgFiY0dABERERERERERtT9MShERERERERERUYtjUoqIiIiIiIiIiFock1JERERERERERNTimJQiIiIiIiIiIqIWx6QUERERERERERG1OCaliIiIiIiIiIioxTEpRURERERERERELY5JKSIiIiIiIiIianFMShE1kI+PD6ZOnWrsMIiIiIiIiIhMApNSRERERERERETU4piUojartLTU2CEQERERERER0X1iUorahA8++AAikQinTp3CuHHjYGdnBz8/P5w8eRLPPvssfHx8YG5uDh8fH0yYMAGJiYkGr4+JiYFIJMKePXvwf//3f3B0dISDgwOeeuop3L5922BsRUUF3nzzTbi6usLCwgKDBg3C8ePHa43r/PnzGDNmDOzs7KBQKBASEoJvvvnGYMzevXshEomwceNGvPXWW3Bzc4OVlRVGjx6NjIwMFBUV4cUXX4SjoyMcHR0xbdo0FBcXN/p3FBMTgy5dukAulyMwMBDr16/H1KlT4ePjYzBOpVJh0aJFCAwMhEKhgIODA4YMGYLDhw/rx2i1WqxcuRIhISEwNzeHra0t+vfvj19++UU/xsfHB48//jh27tyJXr16wdzcHAEBAVi3bt19/+6JiEzRb7/9hpCQEMjlcvj6+uLTTz/Vz2vVGvK5CwAbN27EgAEDYGVlBSsrK4SEhGDt2rX65x9++GEEBwfjxIkTCA8Ph4WFBTp27IiPP/4YWq1WP656btq0aRMWLFgAd3d32NjYYNiwYbhy5Urz/1KIiKjBfvrpJ4hEIvz55581nlu9ejVEIhHOnj0LAPjyyy/h7+8PuVyOoKAgbNy4sdbvBCkpKRg3bhysra1ha2uLiRMn4sSJExCJRIiJiWmBqyLSkRo7AKLGeOqpp/Dss8/i5ZdfRklJCW7duoUuXbrg2Wefhb29PdLS0rB69Wr06dMHFy9ehKOjo8HrZ8yYgVGjRmHjxo1ITk7GvHnzMGnSJOzevVs/5oUXXsD69evxxhtv4NFHH8X58+fx1FNPoaioyOBcV65cQVhYGJydnbFixQo4ODjg22+/xdSpU5GRkYE333zTYPw777yDIUOGICYmBrdu3cIbb7yBCRMmQCqVokePHti0aRPi4+PxzjvvwNraGitWrGjw7yUmJgbTpk3DmDFj8K9//QsFBQX44IMPoFKpIBbfyT1XVlZixIgROHDgAGbPno1HHnkElZWVOHr0KJKSkhAWFgYAmDp1Kr799ltERUVh0aJFkMlkOHXqFG7dumXwvmfOnMHcuXPx9ttvw8XFBV999RWioqLQqVMnDB48uNG/eyIiU/Pnn39izJgxGDBgADZv3gyNRoNPPvkEGRkZBuMa8rn7/vvv48MPP8RTTz2FuXPnQqlU4vz58zVuxKSnp2PixImYO3cu/v73v+PHH3/E/Pnz4e7ujilTphiMfeeddzBw4EB89dVXKCwsxFtvvYXRo0fj0qVLkEgkzfZ7ISKihnv88cfh7OyMr7/+GkOHDjV4LiYmBr169UL37t2xZs0avPTSS3j66aexfPlyFBQUYOHChVCpVAavKSkpwZAhQ5Cbm4ulS5eiU6dO2LlzJyIjI1vysoh0BKI24O9//7sAQHj//ffrHVdZWSkUFxcLlpaWwmeffaY//vXXXwsAhJkzZxqM/+STTw
QAQlpamiAIgnDp0iUBgPD6668bjNuwYYMAQHj++ef1x5599llBLpcLSUlJBmNHjBghWFhYCPn5+YIgCMKePXsEAMLo0aMNxs2ePVsAIMyaNcvg+NixYwV7e/t6r/NuGo1GcHd3F3r16iVotVr98Vu3bglmZmaCt7e3/tj69esFAMKXX35Z5/n2798vABAWLFhQ7/t6e3sLCoVCSExM1B8rKysT7O3thZdeekl/rKG/eyIiU9SvXz/B3d1dKCsr0x8rLCwU7O3theo/wxryuZuQkCBIJBJh4sSJ9b7fQw89JAAQjh07ZnA8KChIGD58uP7n6rlp5MiRBuO+++47AYBw5MiRBl8jERE1vzlz5gjm5ub67xiCIAgXL14UAAgrV64UNBqN4OrqKvTr18/gdYmJiTW+E3zxxRcCAGHHjh0GY1966SUBgPD1118356UQGeDyPWpTnn76aYOfi4uL8dZbb6FTp06QSqWQSqWwsrJCSUkJLl26VOP1TzzxhMHP3bt3BwD9XeY9e/YAACZOnGgwbvz48ZBKDQsLd+/ejaFDh8LT09Pg+NSpU1FaWoojR44YHH/88ccNfg4MDAQAjBo1qsbx3NzcBi/hu3LlCm7fvo3nnnvOYCmIt7e3vvKp2o4dO6BQKDB9+vQ6z7djxw4AwCuvvHLP9w4JCYGXl5f+Z4VCAX9//xp37YF7/+6JiExNSUkJTpw4gaeeegoKhUJ/3NraGqNHj9b/3JDP3djYWGg0mgZ9Nru6uqJv374Gx7p3787PZiKiNmz69OkoKyvDli1b9Me+/vpryOVyPPfcc7hy5QrS09Mxfvx4g9d5eXlh4MCBBsf27dsHa2trPPbYYwbHJ0yY0HwXQFQHJqWoTXFzczP4+bnnnsPnn3+OGTNmYNeuXTh+/DhOnDgBJycnlJWV1Xi9g4ODwc9yuRwA9GNzcnIA6P6gv5tUKq3x2pycnBrxAIC7u7vBuarZ29sb/CyTyeo9Xl5eXuPctakr5tqOZWVlwd3d3WBJ319lZWVBIpHUer6/+uvvBND9Tu/nd09EZGry8vKg1Wrv+fnckM/drKwsAECHDh3u+b78bCYiMj1du3ZFnz598PXXXwMANBoNvv32W4wZMwb29vb67wQuLi41XvvXYzk5OQ0aR9QSmJSiNuXuSqCCggL8+uuvePPNN/H2229j6NCh6NOnD7p164bc3Nz7On/1H+fp6ekGxysrK2skmRwcHJCWllbjHNXNu//az6q51BVzbcecnJxw+/Ztg2a3f+Xk5ASNRlPr+YiIqOHs7OwgEonu+fnckM9dJycnALrGtERE1D5NmzYNR48exaVLl7Bz506kpaVh2rRpAO58J/hrz0Kg5ncCBweHBo0jaglMSlGbJRKJIAiC/q5uta+++goajea+zvnwww8DADZs2GBw/LvvvkNlZaXBsaFDh2L37t01dpBbv349LCws0L9///uKobG6dOkCNzc3bNq0CYIg6I8nJiYa7KgHACNGjEB5eXm9O2qMGDECgG4nDyIiun+Wlpbo27cvfvjhB4Pq16KiIvzvf//T/9yQz92IiAhIJBJ+NhMRtWMTJkyAQqFATEwMYmJi4OHhgYiICAC67wSurq747rvvDF6TlJRU4zvBQw89hKKiIv3y8WqbN29u3gsgqgV336M2y8bGBoMHD8Y///lPODo6wsfHB/v27cPatWtha2t7X+cMDAzEpEmTEB0dDTMzMwwbNgznz5/Hp59+ChsbG4Oxf//73/Hrr79iyJAheP/992Fvb48NGzbgt99+wyeffAKlUtkEV3lvYrEYH374IWbMmIEnn3wSL7zwAvLz8/HBBx/UWAoyYcIEfP3113j55Zdx5coVDBkyBFqtFseOHUNgYCCeffZZhIeHY/LkyVi8eDEyMjLw+OOPQy6XIz4+HhYWFnj11Vdb5LqIiEzBhx9+iMceewyPPvoo5s6dC41Gg6VLl8LS0lJf1duQz10fHx+88847+PDDD1FWVo
YJEyZAqVTi4sWLyM7OxsKFC418pURE1NxsbW3x5JNPIiYmBvn5+XjjjTf0bTnEYjEWLlyIl156CePGjcP06dORn5+
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"Parameter Counts:\n",
|
|||
|
"random_qccnn: 232581\n",
|
|||
|
"qccnn: 26042\n",
|
|||
|
"vgg: 49870\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"execution_count": 16
|
|||
|
},
|
|||
|
{
|
|||
|
"metadata": {},
|
|||
|
"cell_type": "markdown",
|
|||
|
"source": [
|
|||
|
"下面是新版本脚本相对于原始代码的主要改进点及对应说明,已结合关键代码片段进行标注:\n",
|
|||
|
"\n",
|
|||
|
"1. **数据增强(Data Augmentation)**\n",
|
|||
|
"\n",
|
|||
|
" * **原始**:训练和验证都只做了 `Resize` + `ToTensor`,数据量固定;\n",
|
|||
|
" * **改进**:对训练集添加了 `RandomRotation`、`RandomHorizontalFlip`、`RandomVerticalFlip` 等操作,大幅增加了样本多样性,有助于提升模型泛化能力。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" train_transform = transforms.Compose([\n",
|
|||
|
" transforms.Resize((18, 18)),\n",
|
|||
|
" transforms.RandomRotation(15), # 随机旋转\n",
|
|||
|
" transforms.RandomHorizontalFlip(), # 随机水平翻转\n",
|
|||
|
" transforms.RandomVerticalFlip(0.3), # 随机垂直翻转\n",
|
|||
|
" transforms.ToTensor(),\n",
|
|||
|
" transforms.Normalize((0.5,), (0.5,))\n",
|
|||
|
" ])\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"2. **优化器与学习率调度**\n",
|
|||
|
"\n",
|
|||
|
" * **原始**:`SGD(lr=0.01, weight_decay=0.001)` 或 `Adam(lr=1e-5)`,无学习率变化;\n",
|
|||
|
" * **改进**:统一使用 `AdamW`(将权重衰减与梯度更新解耦的 Adam 变体),正则化效果更稳定;添加 `CosineAnnealingLR`,能在训练中动态调整学习率,促进更快收敛、更高精度。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
|
|||
|
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"3. **批大小与训练轮次(Batch Size & Epochs)**\n",
|
|||
|
"\n",
|
|||
|
" * **原始**:`batch_size=64`,`num_epochs=300`;\n",
|
|||
|
" * **改进**:增大到 `batch_size=128`,减少到 `num_epochs=50`。\n",
|
|||
|
"\n",
|
|||
|
" * **批量更大**:更好利用 GPU 并行能力;\n",
|
|||
|
" * **轮次更少**:缩短训练时间,同时在学习率调度下仍能达到高准确度。\n",
|
|||
|
"\n",
|
|||
|
"4. **梯度裁剪(Gradient Clipping)**\n",
|
|||
|
"\n",
|
|||
|
" * 增加了 `torch.nn.utils.clip_grad_norm_(…)`,避免量子网络中可能出现的梯度爆炸,保障训练稳定性。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"5. **统一训练框架与多模型支持**\n",
|
|||
|
"\n",
|
|||
|
" * 将原来三段几乎重复的训练逻辑,抽象成 `train_model(...)` 函数,并通过一个字典 `models` 循环一次性训练:\n",
|
|||
|
"\n",
|
|||
|
" * `random_qccnn`、`qccnn`、`vgg` 三种架构统一流程;\n",
|
|||
|
" * 每种模型均保存最优权重(`save_path`)与训练曲线(`metrics.csv`)。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" models = {\n",
|
|||
|
" 'random_qccnn': (RandomQCCNN(), 1e-3, './random_qccnn_best.pt'),\n",
|
|||
|
" 'qccnn': (QCCNN(), 1e-4, './qccnn_best.pt'),\n",
|
|||
|
" 'vgg': (VGG, 1e-4, './vgg_best.pt')\n",
|
|||
|
" }\n",
|
|||
|
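" for name, (model, lr, save_path) in models.items():\n",
" optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)\n",
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n",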
|
|||
|
" trained, metrics = train_model(\n",
|
|||
|
" model, criterion, optimizer, scheduler,\n",
|
|||
|
" train_loader, valid_loader,\n",
|
|||
|
" num_epochs=50, device=device, save_path=save_path\n",
|
|||
|
" )\n",
|
|||
|
" pd.DataFrame(metrics).to_csv(f'./{name}_metrics.csv', index=False)\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"6. **最佳模型保存和测试分离**\n",
|
|||
|
"\n",
|
|||
|
" * `train_model` 内部自动监控验证集准确率并保存最佳参数;\n",
|
|||
|
" * 训练结束后,再统一加载最佳权重进行测试,确保测试结果与最优状态对应。\n",
|
|||
|
"\n",
|
|||
|
"7. **可视化对比**\n",
|
|||
|
"\n",
|
|||
|
" * 最后一个代码块中,通过 `matplotlib` 并排绘制三种模型的验证准确率曲线,直观比较不同模型的收敛速度和最终性能。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" plt.figure(figsize=(12,5))\n",
|
|||
|
" for i,(name,metrics) in enumerate(all_metrics.items(),1):\n",
|
|||
|
" plt.subplot(1,3,i)\n",
|
|||
|
" plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n",
|
|||
|
" plt.title(name); plt.legend()\n",
|
|||
|
" plt.tight_layout(); plt.show()\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"8. **参数量统计**\n",
|
|||
|
"\n",
|
|||
|
" * 在脚本末尾增加 `count_parameters` 函数,打印三种模型的可训练参数量,帮助评估模型复杂度与性能的权衡。\n",
|
|||
|
"\n",
|
|||
|
" ```python\n",
|
|||
|
" def count_parameters(m):\n",
|
|||
|
" return sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
|
|||
|
" for name,(model,_,_) in models.items():\n",
|
|||
|
" print(f\"{name}: {count_parameters(model)}\")\n",
|
|||
|
" ```\n",
|
|||
|
"\n",
|
|||
|
"---\n",
|
|||
|
"\n",
|
|||
|
"**总体效果**:\n",
|
|||
|
"\n",
|
|||
|
"* **训练速度**:批量更大、轮次更少、学习率动态调度,整体训练时间显著缩短。\n",
|
|||
|
"* **模型准确率**:数据增强 + 优化器 + 调度器 + 梯度裁剪等多项改进,显著提高了各模型在验证集和测试集上的准确率。\n",
|
|||
|
"* **可维护性**:统一框架、函数抽象、循环训练及结果保存,大大简化了代码结构,便于后续扩展和调试。\n"
|
|||
|
],
|
|||
|
"id": "6a5e9602f481107e"
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|