diff --git a/Modify.ipynb b/Modify.ipynb new file mode 100644 index 0000000..4f9c612 --- /dev/null +++ b/Modify.ipynb @@ -0,0 +1,696 @@ +{ + "cells": [ + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.323528Z", + "start_time": "2025-06-24T17:37:28.321407Z" + } + }, + "cell_type": "code", + "source": "# Modify.py\n", + "id": "98da35e8f6af6b7a", + "outputs": [], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.355899Z", + "start_time": "2025-06-24T17:37:28.352650Z" + } + }, + "cell_type": "code", + "source": [ + "import os\n", + "import random\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import deepquantum as dq\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torchvision.transforms as transforms\n", + "from torchvision.datasets import FashionMNIST\n", + "from tqdm import tqdm\n", + "from torch.utils.data import DataLoader\n", + "from multiprocessing import freeze_support\n" + ], + "id": "fba02718b5ce470f", + "outputs": [], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.376643Z", + "start_time": "2025-06-24T17:37:28.373463Z" + } + }, + "cell_type": "code", + "source": [ + "def seed_torch(seed=1024):\n", + " random.seed(seed)\n", + " os.environ['PYTHONHASHSEED'] = str(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True\n" + ], + "id": "e21c1ebc100b5079", + "outputs": [], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.406416Z", + "start_time": "2025-06-24T17:37:28.403413Z" + } + }, + "cell_type": "code", + "source": [ + "def calculate_score(y_true, y_preds):\n", + " preds_prob = 
torch.softmax(y_preds, dim=1)\n", + " preds_class = torch.argmax(preds_prob, dim=1)\n", + " correct = (preds_class == y_true).float()\n", + " return (correct.sum() / len(correct)).cpu().numpy()\n" + ], + "id": "f70cc647d264747", + "outputs": [], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.439020Z", + "start_time": "2025-06-24T17:37:28.433485Z" + } + }, + "cell_type": "code", + "source": [ + "def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):\n", + " model.to(device)\n", + " best_acc = 0.0\n", + " metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}\n", + "\n", + " for epoch in range(1, num_epochs + 1):\n", + " # --- 训练阶段 ---\n", + " model.train()\n", + " running_loss, running_acc = 0.0, 0.0\n", + " for imgs, labels in train_loader:\n", + " imgs, labels = imgs.to(device), labels.to(device)\n", + " optimizer.zero_grad()\n", + " outputs = model(imgs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", + " optimizer.step()\n", + " running_loss += loss.item()\n", + " running_acc += calculate_score(labels, outputs)\n", + "\n", + " train_loss = running_loss / len(train_loader)\n", + " train_acc = running_acc / len(train_loader)\n", + " scheduler.step()\n", + "\n", + " # --- 验证阶段 ---\n", + " model.eval()\n", + " val_loss, val_acc = 0.0, 0.0\n", + " with torch.no_grad():\n", + " for imgs, labels in valid_loader:\n", + " imgs, labels = imgs.to(device), labels.to(device)\n", + " outputs = model(imgs)\n", + " loss = criterion(outputs, labels)\n", + " val_loss += loss.item()\n", + " val_acc += calculate_score(labels, outputs)\n", + "\n", + " valid_loss = val_loss / len(valid_loader)\n", + " valid_acc = val_acc / len(valid_loader)\n", + "\n", + " metrics['epoch'].append(epoch)\n", + " 
metrics['train_loss'].append(train_loss)\n", + " metrics['valid_loss'].append(valid_loss)\n", + " metrics['train_acc'].append(train_acc)\n", + " metrics['valid_acc'].append(valid_acc)\n", + "\n", + " tqdm.write(f\"[{save_path}] Epoch {epoch}/{num_epochs} \"\n", + " f\"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}\")\n", + "\n", + " if valid_acc > best_acc:\n", + " best_acc = valid_acc\n", + " torch.save(model.state_dict(), save_path)\n", + "\n", + " return model, metrics\n" + ], + "id": "7cfe89d63ab0e68e", + "outputs": [], + "execution_count": 13 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.467170Z", + "start_time": "2025-06-24T17:37:28.463678Z" + } + }, + "cell_type": "code", + "source": [ + "def test_model(model, test_loader, device):\n", + " model.to(device).eval()\n", + " acc = 0.0\n", + " with torch.no_grad():\n", + " for imgs, labels in test_loader:\n", + " imgs, labels = imgs.to(device), labels.to(device)\n", + " outputs = model(imgs)\n", + " acc += calculate_score(labels, outputs)\n", + " acc /= len(test_loader)\n", + " print(f\"Test Accuracy: {acc:.4f}\")\n", + " return acc\n" + ], + "id": "235ac16bd786e65f", + "outputs": [], + "execution_count": 14 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T17:37:28.506986Z", + "start_time": "2025-06-24T17:37:28.496003Z" + } + }, + "cell_type": "code", + "source": [ + "singlegate_list = ['rx','ry','rz','s','t','p','u3']\n", + "doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']\n", + "\n", + "class RandomQuantumConvolutionalLayer(nn.Module):\n", + " def __init__(self, nqubit, num_circuits, seed=1024):\n", + " super().__init__()\n", + " random.seed(seed)\n", + " self.nqubit = nqubit\n", + " self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n", + " def circuit(self, nqubit):\n", + " cir = dq.QubitCircuit(nqubit)\n", + " cir.rxlayer(encode=True); cir.barrier()\n", + " for _ in range(3):\n", + " for i in 
range(nqubit):\n", + " getattr(cir, random.choice(singlegate_list))(i)\n", + " c,t = random.sample(range(nqubit),2)\n", + " gate = random.choice(doublegate_list)\n", + " if gate[0] in ['r','s']:\n", + " getattr(cir, gate)([c,t])\n", + " else:\n", + " getattr(cir, gate)(c,t)\n", + " cir.barrier()\n", + " cir.observable(0)\n", + " return cir\n", + " def forward(self, x):\n", + " k,s = 2,2\n", + " x_unf = x.unfold(2,k,s).unfold(3,k,s)\n", + " w = (x.shape[-1]-k)//s + 1\n", + " x_r = x_unf.reshape(-1, self.nqubit)\n", + " exps = []\n", + " for cir in self.cirs:\n", + " cir(x_r)\n", + " exps.append(cir.expectation())\n", + " exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)\n", + " return exps\n", + "\n", + "class RandomQCCNN(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = nn.Sequential(\n", + " RandomQuantumConvolutionalLayer(4,3,seed=1024),\n", + " nn.ReLU(), nn.MaxPool2d(2,1),\n", + " nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)\n", + " )\n", + " self.fc = nn.Sequential(\n", + " nn.Linear(6*6*6,1024), nn.Dropout(0.4),\n", + " nn.Linear(1024,10)\n", + " )\n", + " def forward(self,x):\n", + " x = self.conv(x)\n", + " x = x.view(x.size(0),-1)\n", + " return self.fc(x)\n", + "\n", + "class ParameterizedQuantumConvolutionalLayer(nn.Module):\n", + " def __init__(self,nqubit,num_circuits):\n", + " super().__init__()\n", + " self.nqubit = nqubit\n", + " self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])\n", + " def circuit(self,nqubit):\n", + " cir = dq.QubitCircuit(nqubit)\n", + " cir.rxlayer(encode=True); cir.barrier()\n", + " for _ in range(4):\n", + " cir.rylayer(); cir.cnot_ring(); cir.barrier()\n", + " cir.observable(0)\n", + " return cir\n", + " def forward(self,x):\n", + " k,s = 2,2\n", + " x_unf = x.unfold(2,k,s).unfold(3,k,s)\n", + " w = (x.shape[-1]-k)//s +1\n", + " x_r = x_unf.reshape(-1,self.nqubit)\n", + " exps = []\n", + " for cir in self.cirs:\n", + " cir(x_r); 
exps.append(cir.expectation())\n", + " exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)\n", + " return exps\n", + "\n", + "class QCCNN(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = nn.Sequential(\n", + " ParameterizedQuantumConvolutionalLayer(4,3),\n", + " nn.ReLU(), nn.MaxPool2d(2,1)\n", + " )\n", + " self.fc = nn.Sequential(\n", + " nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),\n", + " nn.Linear(128,10)\n", + " )\n", + " def forward(self,x):\n", + " x = self.conv(x); x = x.view(x.size(0),-1)\n", + " return self.fc(x)\n", + "\n", + "def vgg_block(in_c,out_c,n_convs):\n", + " layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]\n", + " for _ in range(n_convs-1):\n", + " layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]\n", + " layers.append(nn.MaxPool2d(2,2))\n", + " return nn.Sequential(*layers)\n", + "\n", + "VGG = nn.Sequential(\n", + " vgg_block(1,10,3),\n", + " vgg_block(10,16,3),\n", + " nn.Flatten(),\n", + " nn.Linear(16*4*4,120), nn.Sigmoid(),\n", + " nn.Linear(120,84), nn.Sigmoid(),\n", + " nn.Linear(84,10), nn.Softmax(dim=-1)\n", + ")\n" + ], + "id": "c996822b3d5f8305", + "outputs": [], + "execution_count": 15 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-24T18:51:44.622880Z", + "start_time": "2025-06-24T17:37:28.529442Z" + } + }, + "cell_type": "code", + "source": [ + "if __name__ == '__main__':\n", + " freeze_support()\n", + "\n", + " # 数据增广与加载\n", + " train_transform = transforms.Compose([\n", + " transforms.Resize((18, 18)),\n", + " transforms.RandomRotation(15),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomVerticalFlip(0.3),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,))\n", + " ])\n", + " eval_transform = transforms.Compose([\n", + " transforms.Resize((18, 18)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,))\n", + " ])\n", + "\n", + " full_train = 
FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)\n", + " test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)\n", + " train_size = int(0.8 * len(full_train))\n", + " valid_size = len(full_train) - train_size\n", + " train_ds, valid_ds = torch.utils.data.random_split(full_train, [train_size, valid_size])\n", + " valid_ds.dataset.transform = eval_transform\n", + "\n", + " batch_size = 128\n", + " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)\n", + " valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)\n", + " test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)\n", + "\n", + " # 三种模型配置\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " models = {\n", + " 'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),\n", + " 'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),\n", + " 'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')\n", + " }\n", + "\n", + " all_metrics = {}\n", + " for name, (model, lr, save_path) in models.items():\n", + " seed_torch(1024)\n", + " model = model.to(device)\n", + " criterion = nn.CrossEntropyLoss()\n", + " optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)\n", + " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n", + "\n", + " print(f\"\\n=== Training {name} ===\")\n", + " _, metrics = train_model(\n", + " model, criterion, optimizer, scheduler,\n", + " train_loader, valid_loader,\n", + " num_epochs=50, device=device, save_path=save_path\n", + " )\n", + " all_metrics[name] = metrics\n", + " pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)\n", + "\n", + " # 测试与可视化\n", + " plt.figure(figsize=(12,5))\n", + " for i,(name,metrics) in 
enumerate(all_metrics.items(),1):\n", + " model, _, save_path = models[name]\n", + " best_model = model.to(device)\n", + " best_model.load_state_dict(torch.load(save_path))\n", + " print(f\"\\n--- Testing {name} ---\")\n", + " test_model(best_model, test_loader, device)\n", + "\n", + " plt.subplot(1,3,i)\n", + " plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n", + " plt.xlabel('Epoch'); plt.ylabel('Valid Acc')\n", + " plt.title(name); plt.legend()\n", + "\n", + " plt.tight_layout(); plt.show()\n", + "\n", + " # 参数量统计\n", + " def count_parameters(m):\n", + " return sum(p.numel() for p in m.parameters() if p.requires_grad)\n", + "\n", + " print(\"\\nParameter Counts:\")\n", + " for name,(model,_,_) in models.items():\n", + " print(f\"{name}: {count_parameters(model)}\")\n" + ], + "id": "2d6b93bb78001086", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== Training random_qccnn ===\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 1/50 Train Acc: 0.6377 Valid Acc: 0.7426\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 2/50 Train Acc: 0.7589 Valid Acc: 0.7799\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 3/50 Train Acc: 0.7830 Valid Acc: 0.7955\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 4/50 Train Acc: 0.7928 Valid Acc: 0.8010\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 5/50 Train Acc: 0.7997 Valid Acc: 0.8065\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 6/50 Train Acc: 0.8044 Valid Acc: 0.8140\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 7/50 Train Acc: 0.8097 Valid Acc: 0.8157\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 8/50 Train Acc: 0.8155 Valid Acc: 0.8163\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 9/50 Train Acc: 0.8162 Valid Acc: 0.8159\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 10/50 Train Acc: 0.8169 Valid Acc: 0.8160\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 11/50 Train Acc: 0.8210 Valid Acc: 
0.8269\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 12/50 Train Acc: 0.8210 Valid Acc: 0.8266\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 13/50 Train Acc: 0.8241 Valid Acc: 0.8212\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 14/50 Train Acc: 0.8240 Valid Acc: 0.8264\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 15/50 Train Acc: 0.8245 Valid Acc: 0.8231\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 16/50 Train Acc: 0.8270 Valid Acc: 0.8291\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 17/50 Train Acc: 0.8274 Valid Acc: 0.8297\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 18/50 Train Acc: 0.8281 Valid Acc: 0.8338\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 19/50 Train Acc: 0.8295 Valid Acc: 0.8291\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 20/50 Train Acc: 0.8306 Valid Acc: 0.8304\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 21/50 Train Acc: 0.8320 Valid Acc: 0.8280\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 22/50 Train Acc: 0.8316 Valid Acc: 0.8293\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 23/50 Train Acc: 0.8315 Valid Acc: 0.8298\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 24/50 Train Acc: 0.8329 Valid Acc: 0.8282\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 25/50 Train Acc: 0.8321 Valid Acc: 0.8313\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 26/50 Train Acc: 0.8350 Valid Acc: 0.8311\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 27/50 Train Acc: 0.8343 Valid Acc: 0.8335\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 28/50 Train Acc: 0.8347 Valid Acc: 0.8320\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 29/50 Train Acc: 0.8354 Valid Acc: 0.8333\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 30/50 Train Acc: 0.8359 Valid Acc: 0.8314\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 31/50 Train Acc: 0.8380 Valid Acc: 0.8348\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 32/50 Train Acc: 0.8369 Valid Acc: 
0.8330\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 33/50 Train Acc: 0.8375 Valid Acc: 0.8367\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 34/50 Train Acc: 0.8373 Valid Acc: 0.8315\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 35/50 Train Acc: 0.8381 Valid Acc: 0.8352\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 36/50 Train Acc: 0.8393 Valid Acc: 0.8374\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 37/50 Train Acc: 0.8384 Valid Acc: 0.8348\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 38/50 Train Acc: 0.8398 Valid Acc: 0.8355\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 39/50 Train Acc: 0.8402 Valid Acc: 0.8365\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 40/50 Train Acc: 0.8400 Valid Acc: 0.8346\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 41/50 Train Acc: 0.8411 Valid Acc: 0.8374\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 42/50 Train Acc: 0.8409 Valid Acc: 0.8373\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 43/50 Train Acc: 0.8415 Valid Acc: 0.8364\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 44/50 Train Acc: 0.8414 Valid Acc: 0.8359\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 45/50 Train Acc: 0.8409 Valid Acc: 0.8364\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 46/50 Train Acc: 0.8419 Valid Acc: 0.8371\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 47/50 Train Acc: 0.8422 Valid Acc: 0.8369\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 48/50 Train Acc: 0.8421 Valid Acc: 0.8360\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 49/50 Train Acc: 0.8417 Valid Acc: 0.8369\n", + "[./data/notebook2/random_qccnn_best.pt] Epoch 50/50 Train Acc: 0.8408 Valid Acc: 0.8369\n", + "\n", + "=== Training qccnn ===\n", + "[./data/notebook2/qccnn_best.pt] Epoch 1/50 Train Acc: 0.3113 Valid Acc: 0.5076\n", + "[./data/notebook2/qccnn_best.pt] Epoch 2/50 Train Acc: 0.4621 Valid Acc: 0.5548\n", + "[./data/notebook2/qccnn_best.pt] Epoch 3/50 Train Acc: 0.5201 
Valid Acc: 0.5892\n", + "[./data/notebook2/qccnn_best.pt] Epoch 4/50 Train Acc: 0.5642 Valid Acc: 0.6311\n", + "[./data/notebook2/qccnn_best.pt] Epoch 5/50 Train Acc: 0.6056 Valid Acc: 0.6599\n", + "[./data/notebook2/qccnn_best.pt] Epoch 6/50 Train Acc: 0.6317 Valid Acc: 0.6804\n", + "[./data/notebook2/qccnn_best.pt] Epoch 7/50 Train Acc: 0.6514 Valid Acc: 0.6929\n", + "[./data/notebook2/qccnn_best.pt] Epoch 8/50 Train Acc: 0.6678 Valid Acc: 0.7008\n", + "[./data/notebook2/qccnn_best.pt] Epoch 9/50 Train Acc: 0.6809 Valid Acc: 0.7096\n", + "[./data/notebook2/qccnn_best.pt] Epoch 10/50 Train Acc: 0.6877 Valid Acc: 0.7186\n", + "[./data/notebook2/qccnn_best.pt] Epoch 11/50 Train Acc: 0.6976 Valid Acc: 0.7247\n", + "[./data/notebook2/qccnn_best.pt] Epoch 12/50 Train Acc: 0.7031 Valid Acc: 0.7303\n", + "[./data/notebook2/qccnn_best.pt] Epoch 13/50 Train Acc: 0.7119 Valid Acc: 0.7317\n", + "[./data/notebook2/qccnn_best.pt] Epoch 14/50 Train Acc: 0.7164 Valid Acc: 0.7404\n", + "[./data/notebook2/qccnn_best.pt] Epoch 15/50 Train Acc: 0.7211 Valid Acc: 0.7438\n", + "[./data/notebook2/qccnn_best.pt] Epoch 16/50 Train Acc: 0.7249 Valid Acc: 0.7481\n", + "[./data/notebook2/qccnn_best.pt] Epoch 17/50 Train Acc: 0.7294 Valid Acc: 0.7500\n", + "[./data/notebook2/qccnn_best.pt] Epoch 18/50 Train Acc: 0.7345 Valid Acc: 0.7518\n", + "[./data/notebook2/qccnn_best.pt] Epoch 19/50 Train Acc: 0.7350 Valid Acc: 0.7550\n", + "[./data/notebook2/qccnn_best.pt] Epoch 20/50 Train Acc: 0.7391 Valid Acc: 0.7587\n", + "[./data/notebook2/qccnn_best.pt] Epoch 21/50 Train Acc: 0.7434 Valid Acc: 0.7608\n", + "[./data/notebook2/qccnn_best.pt] Epoch 22/50 Train Acc: 0.7443 Valid Acc: 0.7634\n", + "[./data/notebook2/qccnn_best.pt] Epoch 23/50 Train Acc: 0.7481 Valid Acc: 0.7654\n", + "[./data/notebook2/qccnn_best.pt] Epoch 24/50 Train Acc: 0.7498 Valid Acc: 0.7683\n", + "[./data/notebook2/qccnn_best.pt] Epoch 25/50 Train Acc: 0.7529 Valid Acc: 0.7696\n", + "[./data/notebook2/qccnn_best.pt] Epoch 26/50 
Train Acc: 0.7547 Valid Acc: 0.7708\n", + "[./data/notebook2/qccnn_best.pt] Epoch 27/50 Train Acc: 0.7547 Valid Acc: 0.7723\n", + "[./data/notebook2/qccnn_best.pt] Epoch 28/50 Train Acc: 0.7580 Valid Acc: 0.7736\n", + "[./data/notebook2/qccnn_best.pt] Epoch 29/50 Train Acc: 0.7571 Valid Acc: 0.7749\n", + "[./data/notebook2/qccnn_best.pt] Epoch 30/50 Train Acc: 0.7602 Valid Acc: 0.7760\n", + "[./data/notebook2/qccnn_best.pt] Epoch 31/50 Train Acc: 0.7610 Valid Acc: 0.7767\n", + "[./data/notebook2/qccnn_best.pt] Epoch 32/50 Train Acc: 0.7618 Valid Acc: 0.7764\n", + "[./data/notebook2/qccnn_best.pt] Epoch 33/50 Train Acc: 0.7630 Valid Acc: 0.7784\n", + "[./data/notebook2/qccnn_best.pt] Epoch 34/50 Train Acc: 0.7632 Valid Acc: 0.7791\n", + "[./data/notebook2/qccnn_best.pt] Epoch 35/50 Train Acc: 0.7627 Valid Acc: 0.7786\n", + "[./data/notebook2/qccnn_best.pt] Epoch 36/50 Train Acc: 0.7653 Valid Acc: 0.7803\n", + "[./data/notebook2/qccnn_best.pt] Epoch 37/50 Train Acc: 0.7640 Valid Acc: 0.7811\n", + "[./data/notebook2/qccnn_best.pt] Epoch 38/50 Train Acc: 0.7674 Valid Acc: 0.7799\n", + "[./data/notebook2/qccnn_best.pt] Epoch 39/50 Train Acc: 0.7649 Valid Acc: 0.7816\n", + "[./data/notebook2/qccnn_best.pt] Epoch 40/50 Train Acc: 0.7661 Valid Acc: 0.7823\n", + "[./data/notebook2/qccnn_best.pt] Epoch 41/50 Train Acc: 0.7668 Valid Acc: 0.7818\n", + "[./data/notebook2/qccnn_best.pt] Epoch 42/50 Train Acc: 0.7662 Valid Acc: 0.7818\n", + "[./data/notebook2/qccnn_best.pt] Epoch 43/50 Train Acc: 0.7668 Valid Acc: 0.7824\n", + "[./data/notebook2/qccnn_best.pt] Epoch 44/50 Train Acc: 0.7678 Valid Acc: 0.7825\n", + "[./data/notebook2/qccnn_best.pt] Epoch 45/50 Train Acc: 0.7677 Valid Acc: 0.7826\n", + "[./data/notebook2/qccnn_best.pt] Epoch 46/50 Train Acc: 0.7666 Valid Acc: 0.7827\n", + "[./data/notebook2/qccnn_best.pt] Epoch 47/50 Train Acc: 0.7689 Valid Acc: 0.7828\n", + "[./data/notebook2/qccnn_best.pt] Epoch 48/50 Train Acc: 0.7675 Valid Acc: 0.7827\n", + 
"[./data/notebook2/qccnn_best.pt] Epoch 49/50 Train Acc: 0.7678 Valid Acc: 0.7828\n", + "[./data/notebook2/qccnn_best.pt] Epoch 50/50 Train Acc: 0.7677 Valid Acc: 0.7826\n", + "\n", + "=== Training vgg ===\n", + "[./data/notebook2/vgg_best.pt] Epoch 1/50 Train Acc: 0.2536 Valid Acc: 0.4195\n", + "[./data/notebook2/vgg_best.pt] Epoch 2/50 Train Acc: 0.4692 Valid Acc: 0.5073\n", + "[./data/notebook2/vgg_best.pt] Epoch 3/50 Train Acc: 0.5270 Valid Acc: 0.5266\n", + "[./data/notebook2/vgg_best.pt] Epoch 4/50 Train Acc: 0.5354 Valid Acc: 0.5351\n", + "[./data/notebook2/vgg_best.pt] Epoch 5/50 Train Acc: 0.5935 Valid Acc: 0.6175\n", + "[./data/notebook2/vgg_best.pt] Epoch 6/50 Train Acc: 0.6366 Valid Acc: 0.6626\n", + "[./data/notebook2/vgg_best.pt] Epoch 7/50 Train Acc: 0.6866 Valid Acc: 0.7303\n", + "[./data/notebook2/vgg_best.pt] Epoch 8/50 Train Acc: 0.7415 Valid Acc: 0.7512\n", + "[./data/notebook2/vgg_best.pt] Epoch 9/50 Train Acc: 0.7542 Valid Acc: 0.7582\n", + "[./data/notebook2/vgg_best.pt] Epoch 10/50 Train Acc: 0.7608 Valid Acc: 0.7646\n", + "[./data/notebook2/vgg_best.pt] Epoch 11/50 Train Acc: 0.7677 Valid Acc: 0.7719\n", + "[./data/notebook2/vgg_best.pt] Epoch 12/50 Train Acc: 0.7705 Valid Acc: 0.7731\n", + "[./data/notebook2/vgg_best.pt] Epoch 13/50 Train Acc: 0.7716 Valid Acc: 0.7736\n", + "[./data/notebook2/vgg_best.pt] Epoch 14/50 Train Acc: 0.7748 Valid Acc: 0.7749\n", + "[./data/notebook2/vgg_best.pt] Epoch 15/50 Train Acc: 0.7763 Valid Acc: 0.7751\n", + "[./data/notebook2/vgg_best.pt] Epoch 16/50 Train Acc: 0.7790 Valid Acc: 0.7768\n", + "[./data/notebook2/vgg_best.pt] Epoch 17/50 Train Acc: 0.7801 Valid Acc: 0.7802\n", + "[./data/notebook2/vgg_best.pt] Epoch 18/50 Train Acc: 0.7818 Valid Acc: 0.7828\n", + "[./data/notebook2/vgg_best.pt] Epoch 19/50 Train Acc: 0.7827 Valid Acc: 0.7816\n", + "[./data/notebook2/vgg_best.pt] Epoch 20/50 Train Acc: 0.7827 Valid Acc: 0.7840\n", + "[./data/notebook2/vgg_best.pt] Epoch 21/50 Train Acc: 0.7849 Valid Acc: 
0.7858\n", + "[./data/notebook2/vgg_best.pt] Epoch 22/50 Train Acc: 0.7877 Valid Acc: 0.7849\n", + "[./data/notebook2/vgg_best.pt] Epoch 23/50 Train Acc: 0.7880 Valid Acc: 0.7849\n", + "[./data/notebook2/vgg_best.pt] Epoch 24/50 Train Acc: 0.7891 Valid Acc: 0.7860\n", + "[./data/notebook2/vgg_best.pt] Epoch 25/50 Train Acc: 0.7903 Valid Acc: 0.7884\n", + "[./data/notebook2/vgg_best.pt] Epoch 26/50 Train Acc: 0.7910 Valid Acc: 0.7900\n", + "[./data/notebook2/vgg_best.pt] Epoch 27/50 Train Acc: 0.7919 Valid Acc: 0.7886\n", + "[./data/notebook2/vgg_best.pt] Epoch 28/50 Train Acc: 0.7937 Valid Acc: 0.7906\n", + "[./data/notebook2/vgg_best.pt] Epoch 29/50 Train Acc: 0.7934 Valid Acc: 0.7876\n", + "[./data/notebook2/vgg_best.pt] Epoch 30/50 Train Acc: 0.7946 Valid Acc: 0.7902\n", + "[./data/notebook2/vgg_best.pt] Epoch 31/50 Train Acc: 0.7959 Valid Acc: 0.7933\n", + "[./data/notebook2/vgg_best.pt] Epoch 32/50 Train Acc: 0.7956 Valid Acc: 0.7933\n", + "[./data/notebook2/vgg_best.pt] Epoch 33/50 Train Acc: 0.7971 Valid Acc: 0.7920\n", + "[./data/notebook2/vgg_best.pt] Epoch 34/50 Train Acc: 0.7974 Valid Acc: 0.7932\n", + "[./data/notebook2/vgg_best.pt] Epoch 35/50 Train Acc: 0.7980 Valid Acc: 0.7948\n", + "[./data/notebook2/vgg_best.pt] Epoch 36/50 Train Acc: 0.7985 Valid Acc: 0.7950\n", + "[./data/notebook2/vgg_best.pt] Epoch 37/50 Train Acc: 0.7990 Valid Acc: 0.7959\n", + "[./data/notebook2/vgg_best.pt] Epoch 38/50 Train Acc: 0.7992 Valid Acc: 0.7949\n", + "[./data/notebook2/vgg_best.pt] Epoch 39/50 Train Acc: 0.8001 Valid Acc: 0.7960\n", + "[./data/notebook2/vgg_best.pt] Epoch 40/50 Train Acc: 0.8006 Valid Acc: 0.7957\n", + "[./data/notebook2/vgg_best.pt] Epoch 41/50 Train Acc: 0.8007 Valid Acc: 0.7963\n", + "[./data/notebook2/vgg_best.pt] Epoch 42/50 Train Acc: 0.8015 Valid Acc: 0.7959\n", + "[./data/notebook2/vgg_best.pt] Epoch 43/50 Train Acc: 0.8014 Valid Acc: 0.7962\n", + "[./data/notebook2/vgg_best.pt] Epoch 44/50 Train Acc: 0.8016 Valid Acc: 0.7965\n", + 
"[./data/notebook2/vgg_best.pt] Epoch 45/50 Train Acc: 0.8018 Valid Acc: 0.7958\n", + "[./data/notebook2/vgg_best.pt] Epoch 46/50 Train Acc: 0.8021 Valid Acc: 0.7965\n", + "[./data/notebook2/vgg_best.pt] Epoch 47/50 Train Acc: 0.8025 Valid Acc: 0.7966\n", + "[./data/notebook2/vgg_best.pt] Epoch 48/50 Train Acc: 0.8026 Valid Acc: 0.7962\n", + "[./data/notebook2/vgg_best.pt] Epoch 49/50 Train Acc: 0.8025 Valid Acc: 0.7970\n", + "[./data/notebook2/vgg_best.pt] Epoch 50/50 Train Acc: 0.8025 Valid Acc: 0.7970\n", + "\n", + "--- Testing random_qccnn ---\n", + "Test Accuracy: 0.8249\n", + "\n", + "--- Testing qccnn ---\n", + "Test Accuracy: 0.7720\n", + "\n", + "--- Testing vgg ---\n", + "Test Accuracy: 0.7899\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Parameter Counts:\n", + "random_qccnn: 232581\n", + "qccnn: 26042\n", + "vgg: 49870\n" + ] + } + ], + "execution_count": 16 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "下面是新版本脚本相对于原始代码的主要改进点及对应说明,已结合关键代码片段进行标注:\n", + "\n", + "1. **数据增强(Data Augmentation)**\n", + "\n", + " * **原始**:训练和验证都只做了 `Resize` + `ToTensor`,数据量固定;\n", + " * **改进**:对训练集添加了 `RandomRotation`、`RandomHorizontalFlip`、`RandomVerticalFlip` 等操作,大幅增加了样本多样性,有助于提升模型泛化能力。\n", + "\n", + " ```python\n", + " train_transform = transforms.Compose([\n", + " transforms.Resize((18, 18)),\n", + " transforms.RandomRotation(15), # 随机旋转\n", + " transforms.RandomHorizontalFlip(), # 随机水平翻转\n", + " transforms.RandomVerticalFlip(0.3), # 随机垂直翻转\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,))\n", + " ])\n", + " ```\n", + "\n", + "2. **优化器与学习率调度**\n", + "\n", + " * **原始**:`SGD(lr=0.01, weight_decay=0.001)` 或 `Adam(lr=1e-5)`,无学习率变化;\n", + " * **改进**:统一使用 `AdamW`(带权重衰减的 Adam),更稳定;添加 `CosineAnnealingLR`,能在训练中动态调整学习率,促进更快收敛、更高精度。\n", + "\n", + " ```python\n", + " optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", + " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)\n", + " ```\n", + "\n", + "3. **批大小与训练轮次(Batch Size & Epochs)**\n", + "\n", + " * **原始**:`batch_size=64`,`num_epochs=300`;\n", + " * **改进**:增大到 `batch_size=128`,减少到 `num_epochs=50`。\n", + "\n", + " * **批量更大**:更好利用 GPU 并行能力;\n", + " * **轮次更少**:缩短训练时间,同时在学习率调度下仍能达到高准确度。\n", + "\n", + "4. **梯度裁剪(Gradient Clipping)**\n", + "\n", + " * 增加了 `torch.nn.utils.clip_grad_norm_(…)`,避免量子网络中可能出现的梯度爆炸,保障训练稳定性。\n", + "\n", + " ```python\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", + " ```\n", + "\n", + "5. 
**统一训练框架与多模型支持**\n", + "\n", + " * 将原来三段几乎重复的训练逻辑,抽象成 `train_model(...)` 函数,并通过一个字典 `models` 循环一次性训练:\n", + "\n", + " * `random_qccnn`、`qccnn`、`vgg` 三种架构统一流程;\n", + " * 每种模型均保存最优权重(`save_path`)与训练曲线(`{name}_metrics.csv`)。\n", + "\n", + " ```python\n", + " models = {\n", + " 'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),\n", + " 'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),\n", + " 'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')\n", + " }\n", + " for name, (model, lr, save_path) in models.items():\n", + " trained, metrics = train_model(\n", + " model, criterion, optimizer, scheduler,\n", + " train_loader, valid_loader,\n", + " num_epochs=50, device=device, save_path=save_path\n", + " )\n", + " pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)\n", + " ```\n", + "\n", + "6. **最佳模型保存和测试分离**\n", + "\n", + " * `train_model` 内部自动监控验证集准确率并保存最佳参数;\n", + " * 训练结束后,再统一加载最佳权重进行测试,确保测试结果与最优状态对应。\n", + "\n", + "7. **可视化对比**\n", + "\n", + " * 最后一个代码块中,通过 `matplotlib` 并排绘制三种模型的验证准确率曲线,直观比较不同模型的收敛速度和最终性能。\n", + "\n", + " ```python\n", + " plt.figure(figsize=(12,5))\n", + " for i,(name,metrics) in enumerate(all_metrics.items(),1):\n", + " plt.subplot(1,3,i)\n", + " plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')\n", + " plt.title(name); plt.legend()\n", + " plt.tight_layout(); plt.show()\n", + " ```\n", + "\n", + "8. 
**参数量统计**\n", + "\n", + " * 在脚本末尾增加 `count_parameters` 函数,打印三种模型的可训练参数量,帮助评估模型复杂度与性能的权衡。\n", + "\n", + " ```python\n", + " def count_parameters(m):\n", + " return sum(p.numel() for p in m.parameters() if p.requires_grad)\n", + " for name,(model,_,_) in models.items():\n", + " print(f\"{name}: {count_parameters(model)}\")\n", + " ```\n", + "\n", + "---\n", + "\n", + "**总体效果**:\n", + "\n", + "* **训练速度**:批量更大、轮次更少、学习率动态调度,整体训练时间显著缩短。\n", + "* **模型准确率**:数据增强 + 优化器 + 调度器 + 梯度裁剪等多项改进,显著提高了各模型在验证集和测试集上的准确率。\n", + "* **可维护性**:统一框架、函数抽象、循环训练及结果保存,大大简化了代码结构,便于后续扩展和调试。\n" + ], + "id": "6a5e9602f481107e" + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Modify.py b/Modify.py new file mode 100644 index 0000000..dcebaa7 --- /dev/null +++ b/Modify.py @@ -0,0 +1,291 @@ +# Modify.py + +#%% 导入所有需要的包 +import os +import random + +import numpy as np +import pandas as pd +import deepquantum as dq +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.transforms as transforms +from torchvision.datasets import FashionMNIST +from tqdm import tqdm +from torch.utils.data import DataLoader +from multiprocessing import freeze_support + +#%% 设置随机种子以保证可复现 +def seed_torch(seed=1024): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + +#%% 准确率计算函数 +def calculate_score(y_true, y_preds): + preds_prob = torch.softmax(y_preds, dim=1) + preds_class = torch.argmax(preds_prob, dim=1) + correct = (preds_class == y_true).float() + return (correct.sum() / len(correct)).cpu().numpy() + +#%% 训练与验证函数 +def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path): + model.to(device) + best_acc = 0.0 + metrics = {'epoch': [], 'train_acc': [], 'valid_acc': 
class RandomQuantumConvolutionalLayer(nn.Module):
    """Quantum "convolution" layer built from randomly generated circuits.

    Every 2x2 patch (stride 2) of a single-channel image is fed to each
    circuit, and the circuit's expectation value becomes one output pixel.
    Each circuit therefore produces one output channel.

    Args:
        nqubit: Number of qubits per circuit. ``forward`` reshapes every patch
            to ``nqubit`` values, so it has to match the 2x2 patch size (4).
        num_circuits: Number of random circuits, i.e. output channels.
        seed: Seed handed to ``random.seed`` before the circuits are drawn, so
            the same seed always yields the same gate sequence.
    """

    def __init__(self, nqubit, num_circuits, seed=1024):
        super().__init__()
        random.seed(seed)
        self.nqubit = nqubit
        self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])

    def circuit(self, nqubit):
        """Build one random circuit: RX data encoding plus three random gate rounds."""
        cir = dq.QubitCircuit(nqubit)
        cir.rxlayer(encode=True)
        cir.barrier()
        for _ in range(3):
            # One random single-qubit gate on every wire ...
            for i in range(nqubit):
                getattr(cir, random.choice(singlegate_list))(i)
            # ... followed by one random two-qubit gate on two distinct wires.
            c, t = random.sample(range(nqubit), 2)
            gate = random.choice(doublegate_list)
            if gate[0] in ['r', 's']:
                # rxx / ryy / rzz / swap are called with a list of wires.
                getattr(cir, gate)([c, t])
            else:
                # Controlled gates are called with (control, target).
                getattr(cir, gate)(c, t)
            cir.barrier()
        cir.observable(0)
        return cir

    def forward(self, x):
        """Map ``(batch, 1, H, W)`` images to ``(batch, num_circuits, h, w)``.

        NOTE(review): assumes a single input channel; with more channels the
        patch count no longer matches the output shape.
        """
        k, s = 2, 2
        patches = x.unfold(2, k, s).unfold(3, k, s)
        h = (x.shape[-2] - k) // s + 1
        w = (x.shape[-1] - k) // s + 1
        x_r = patches.reshape(-1, self.nqubit)
        exps = []
        for cir in self.cirs:
            cir(x_r)
            exps.append(cir.expectation())
        # Rows of x_r are ordered (batch, row, col) and stacking appends the
        # circuit index last, so the data is channel-last.  Reshaping straight
        # to (batch, C, h, w) would interleave channels with spatial positions;
        # reshape channel-last first and then move the channel axis forward.
        exps = torch.stack(exps, 1).reshape(x.size(0), h, w, len(self.cirs))
        return exps.permute(0, 3, 1, 2).contiguous()
#%% Main entry point
if __name__ == '__main__':
    freeze_support()

    # Seed before anything random happens (dataset split, weight init of the
    # models built below) so that a run is reproducible end to end.
    seed_torch(1024)

    # Data augmentation and loading
    train_transform = transforms.Compose([
        transforms.Resize((18, 18)),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(0.3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((18, 18)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Two views of the same training images: one augmented, one not.
    # The subsets produced by random_split share a single underlying dataset,
    # so assigning `valid_ds.dataset.transform` would also strip augmentation
    # from the training subset.  Split by index over two separate instances.
    full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)
    full_valid = FashionMNIST(root='./data/notebook2', train=True, transform=eval_transform, download=True)
    test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)
    train_size = int(0.8 * len(full_train))
    split_rng = torch.Generator().manual_seed(1024)
    shuffled = torch.randperm(len(full_train), generator=split_rng).tolist()
    train_ds = torch.utils.data.Subset(full_train, shuffled[:train_size])
    valid_ds = torch.utils.data.Subset(full_valid, shuffled[train_size:])

    batch_size = 128
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

    # Configuration of the three models: (model, learning rate, checkpoint path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = {
        'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),
        'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),
        'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')
    }

    all_metrics = {}
    for name, (model, lr, save_path) in models.items():
        seed_torch(1024)
        model = model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

        print(f"\n=== Training {name} ===")
        _, metrics = train_model(
            model, criterion, optimizer, scheduler,
            train_loader, valid_loader,
            num_epochs=50, device=device, save_path=save_path
        )
        all_metrics[name] = metrics
        pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)

    # Testing and visualisation
    plt.figure(figsize=(12, 5))
    for i, (name, metrics) in enumerate(all_metrics.items(), 1):
        model, _, save_path = models[name]
        best_model = model.to(device)
        # map_location keeps a checkpoint written on GPU loadable on CPU.
        best_model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"\n--- Testing {name} ---")
        test_model(best_model, test_loader, device)

        plt.subplot(1, 3, i)
        plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')
        plt.xlabel('Epoch')
        plt.ylabel('Valid Acc')
        plt.title(name)
        plt.legend()

    plt.tight_layout()
    plt.show()

    # Trainable-parameter counts
    def count_parameters(m):
        """Return the number of parameters of ``m`` that require gradients."""
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    print("\nParameter Counts:")
    for name, (model, _, _) in models.items():
        print(f"{name}: {count_parameters(model)}")
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):
    """Train ``model`` and validate it after every epoch.

    Args:
        model (torch.nn.Module): Model to train.
        criterion (torch.nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer updating ``model``.
        train_loader (torch.utils.data.DataLoader): Training batches.
        valid_loader (torch.utils.data.DataLoader): Validation batches.
        num_epochs (int): Number of epochs to run.
        device (torch.device): Device the batches are moved to.

    Returns:
        tuple: ``(model, metrics)`` where ``metrics`` is a dict with the keys
        ``epoch``, ``train_acc``, ``valid_acc``, ``train_loss`` and
        ``valid_loss``, each holding one entry per epoch.
    """
    train_loss_list = []
    valid_loss_list = []
    train_acc_list = []
    valid_acc_list = []

    with tqdm(total=num_epochs) as pbar:
        for _ in range(num_epochs):
            # Training phase.  train() must be re-entered every epoch: the
            # eval() call of the validation phase below would otherwise stay
            # in effect, and every epoch after the first would train with
            # dropout disabled.
            model.train()
            train_loss = 0.0
            train_acc = 0.0
            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_acc += calculate_score(labels, outputs)

            # Per-epoch figures are the mean of the per-batch values.
            train_loss /= len(train_loader)
            train_acc /= len(train_loader)

            # Validation phase
            model.eval()
            valid_loss = 0.0
            valid_acc = 0.0
            with torch.no_grad():
                for images, labels in valid_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    valid_loss += loss.item()
                    valid_acc += calculate_score(labels, outputs)

            valid_loss /= len(valid_loader)
            valid_acc /= len(valid_loader)

            pbar.set_description(f"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}")
            pbar.update()

            train_loss_list.append(train_loss)
            valid_loss_list.append(valid_loss)
            train_acc_list.append(train_acc)
            valid_acc_list.append(valid_acc)

    metrics = {'epoch': list(range(1, num_epochs + 1)),
               'train_acc': train_acc_list,
               'valid_acc': valid_acc_list,
               'train_loss': train_loss_list,
               'valid_loss': valid_loss_list}

    return model, metrics
# Random quantum convolutional layer
class RandomQuantumConvolutionalLayer(nn.Module):
    """Quantum "convolution" layer whose kernels are random circuits.

    Each 2x2 patch (stride 2) of a single-channel image is passed to every
    circuit; the circuit's expectation value is one output pixel and each
    circuit contributes one output channel.

    Args:
        nqubit: Qubits per circuit; must equal the number of values in a
            patch (4), because ``forward`` reshapes patches to ``nqubit``.
        num_circuits: Number of random circuits, i.e. output channels.
        seed: Seed for ``random.seed`` so the drawn gates are repeatable.
    """

    def __init__(self, nqubit, num_circuits, seed: int = 1024):
        super(RandomQuantumConvolutionalLayer, self).__init__()
        random.seed(seed)
        self.nqubit = nqubit
        self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])

    def circuit(self, nqubit):
        """Build one random circuit: RX data encoding plus three random gate rounds."""
        cir = dq.QubitCircuit(nqubit)
        # Data-encoding layer; per the original note it does not alter the
        # circuit structure of the reference paper.
        cir.rxlayer(encode=True)
        cir.barrier()
        for _ in range(3):
            for i in range(nqubit):
                singlegate = random.choice(singlegate_list)
                getattr(cir, singlegate)(i)
            # Draw from every wire.  The previous range(0, nqubit - 1) could
            # never select the last qubit for a two-qubit gate.
            control_bit, target_bit = random.sample(range(nqubit), 2)
            doublegate = random.choice(doublegate_list)
            if doublegate[0] in ['r', 's']:
                # rxx / ryy / rzz / swap are called with a list of wires.
                getattr(cir, doublegate)([control_bit, target_bit])
            else:
                getattr(cir, doublegate)(control_bit, target_bit)
            cir.barrier()

        cir.observable(0)
        return cir

    def forward(self, x):
        """Map ``(batch, 1, H, W)`` images to ``(batch, num_circuits, h, w)``.

        NOTE(review): assumes a single input channel; with more channels the
        patch count no longer matches the output shape.
        """
        kernel_size, stride = 2, 2
        # e.g. [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
        x_unfold = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
        h = (x.shape[-2] - kernel_size) // stride + 1
        w = (x.shape[-1] - kernel_size) // stride + 1
        x_reshape = x_unfold.reshape(-1, self.nqubit)

        exps = []
        for cir in self.cirs:  # one circuit per output channel
            cir(x_reshape)
            exps.append(cir.expectation())

        # Patches are ordered (batch, row, col) and stacking appends the
        # circuit index last, i.e. the data is channel-last.  Reshape
        # accordingly and then move channels forward; a direct reshape to
        # (batch, C, h, w) interleaves channels with spatial positions.
        # The channel count follows the number of circuits, not a fixed 3.
        exps = torch.stack(exps, dim=1)
        exps = exps.reshape(x.shape[0], h, w, len(self.cirs))
        return exps.permute(0, 3, 1, 2).contiguous()
class ParameterizedQuantumConvolutionalLayer(nn.Module):
    """Quantum "convolution" layer whose kernels are parameterized circuits.

    Each 2x2 patch (stride 2) of a single-channel image is passed to every
    circuit; the circuit's expectation value is one output pixel and each
    circuit contributes one output channel.

    Args:
        nqubit: Qubits per circuit; must equal the number of values in a
            patch (4), because ``forward`` reshapes patches to ``nqubit``.
        num_circuits: Number of circuits, i.e. output channels.
    """

    def __init__(self, nqubit, num_circuits):
        super().__init__()
        self.nqubit = nqubit
        self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])

    def circuit(self, nqubit):
        """Build one circuit: RX data encoding, then four RY + CNOT-ring layers."""
        cir = dq.QubitCircuit(nqubit)
        # Data-encoding layer; per the original note it does not alter the
        # circuit structure of the reference paper.
        cir.rxlayer(encode=True)
        cir.barrier()
        # Depth 4, matching the reference paper according to the original
        # note (16 tunable parameters for 4 qubits).
        for _ in range(4):
            cir.rylayer()
            cir.cnot_ring()
            cir.barrier()

        cir.observable(0)
        return cir

    def forward(self, x):
        """Map ``(batch, 1, H, W)`` images to ``(batch, num_circuits, h, w)``.

        NOTE(review): assumes a single input channel; with more channels the
        patch count no longer matches the output shape.
        """
        kernel_size, stride = 2, 2
        # e.g. [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
        x_unfold = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
        h = (x.shape[-2] - kernel_size) // stride + 1
        w = (x.shape[-1] - kernel_size) // stride + 1
        x_reshape = x_unfold.reshape(-1, self.nqubit)

        exps = []
        for cir in self.cirs:  # one circuit per output channel
            cir(x_reshape)
            exps.append(cir.expectation())

        # Patches are ordered (batch, row, col) and stacking appends the
        # circuit index last, i.e. the data is channel-last.  Reshape
        # accordingly and then move channels forward; a direct reshape to
        # (batch, C, h, w) interleaves channels with spatial positions.
        # The channel count follows the number of circuits, not a fixed 3.
        exps = torch.stack(exps, dim=1)
        exps = exps.reshape(x.shape[0], h, w, len(self.cirs))
        return exps.permute(0, 3, 1, 2).contiguous()
def vgg_block(in_channel, out_channel, num_convs):
    """Build one VGG stage.

    The stage consists of ``num_convs`` 3x3 convolutions (padding 1), each
    followed by ReLU, and a final 2x2 max-pool with stride 2.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels produced by every convolution of the stage.
        num_convs: Number of convolution + ReLU pairs; must be at least 1.

    Returns:
        nn.Sequential: The assembled stage.

    Raises:
        ValueError: If ``num_convs`` is smaller than 1.
    """
    if num_convs < 1:
        raise ValueError("num_convs must be >= 1")
    layers = []
    channels = in_channel
    for _ in range(num_convs):
        layers.append(nn.Conv2d(channels, out_channel, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        channels = out_channel  # only the first convolution changes the width
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)


# The network emits raw logits.  nn.CrossEntropyLoss applies log-softmax
# itself, so a trailing nn.Softmax would squash the logits twice and flatten
# the gradients.  Softmax has no parameters, so state_dict keys are unchanged.
VGG = nn.Sequential(
    vgg_block(1, 10, 3),    # for 18x18 inputs: 18x18 -> 9x9
    vgg_block(10, 16, 3),   # 9x9 -> 4x4
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)
def count_parameters(model):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set contribute; each contributes
    its number of elements.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not trainable
        total += param.numel()
    return total
模型可训练参数量:{number_params_VGG}\t QCCNN模型可训练参数量:{number_params_QCCNN}') \ No newline at end of file diff --git a/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte b/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..37bac79 Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte differ diff --git a/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz b/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..667844f Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte b/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..2195a4d Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz b/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..abdddb8 Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte b/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..ff2f5a9 Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte differ diff --git a/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz b/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..e6ee0e3 Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte b/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..30424ca Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte differ diff --git 
a/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz b/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..9c4aae2 Binary files /dev/null and b/data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/data/notebook2/qccnn_best.pt b/data/notebook2/qccnn_best.pt new file mode 100644 index 0000000..9aae44c Binary files /dev/null and b/data/notebook2/qccnn_best.pt differ diff --git a/data/notebook2/qccnn_metrics.csv b/data/notebook2/qccnn_metrics.csv new file mode 100644 index 0000000..4a094fc --- /dev/null +++ b/data/notebook2/qccnn_metrics.csv @@ -0,0 +1,51 @@ +epoch,train_acc,valid_acc,train_loss,valid_loss +1,0.3112916666666667,0.5076444892473119,2.1129255460103353,1.7863053237238238 +2,0.46210416666666665,0.5547715053763441,1.5605127541224162,1.3499009532313193 +3,0.5200833333333333,0.5892137096774194,1.298180630683899,1.1617528520604616 +4,0.5641666666666667,0.6311323924731183,1.161390997727712,1.0419180438082705 +5,0.6055833333333334,0.6598622311827957,1.062113331158956,0.9560088053826363 +6,0.6316666666666667,0.6804435483870968,0.9944041889508565,0.8958990977656457 +7,0.6513958333333333,0.6928763440860215,0.9395476338068645,0.8504296951396491 +8,0.6677708333333333,0.7007728494623656,0.8985106336275737,0.8148919709267155 +9,0.6809375,0.7095934139784946,0.8673048818906148,0.7897684433126962 +10,0.6876666666666666,0.7185819892473119,0.8408271819750468,0.7671139176173877 +11,0.697625,0.7247143817204301,0.8202435002326965,0.7502269674372929 +12,0.7030625,0.7303427419354839,0.8024592914581299,0.7343912464316174 +13,0.7118958333333333,0.7316868279569892,0.7869214735031128,0.7222210591839205 +14,0.7163958333333333,0.7404233870967742,0.7718908907572428,0.7093277560767307 +15,0.7211041666666667,0.7437836021505376,0.7589795832633972,0.6981546148177116 +16,0.724875,0.748067876344086,0.7483940289815267,0.6881407947950465 +17,0.7293958333333334,0.75,0.7373875479698181,0.6788085179944192 
+18,0.7345,0.7517641129032258,0.7301613558133443,0.6724013622089099 +19,0.7350416666666667,0.7550403225806451,0.7210039290587107,0.6651292238184201 +20,0.7391458333333333,0.758736559139785,0.711990216255188,0.6584682823509298 +21,0.7433958333333334,0.7608366935483871,0.7047773049672444,0.6513981505106854 +22,0.7442916666666667,0.7634408602150538,0.6997552577654521,0.6454382868864204 +23,0.7481458333333333,0.7653729838709677,0.6937046423753103,0.6401395256160408 +24,0.7498333333333334,0.7683131720430108,0.6873725473880767,0.6349253362865859 +25,0.7529166666666667,0.769573252688172,0.6828897857666015,0.6308066611007977 +26,0.7547291666666667,0.7708333333333334,0.6747500312328338,0.6260690948655528 +27,0.7547291666666667,0.7722614247311828,0.6727884339491527,0.6230032514500362 +28,0.7580208333333334,0.7736055107526881,0.6705719958146413,0.6201616948650729 +29,0.757125,0.7748655913978495,0.6678872625033061,0.6165956912502166 +30,0.76025,0.7759576612903226,0.6641453323364258,0.6141596103227267 +31,0.7609583333333333,0.7767137096774194,0.6619020007451375,0.6114715427480718 +32,0.7618333333333334,0.776377688172043,0.6576849890549977,0.609050533463878 +33,0.7629583333333333,0.7783938172043011,0.6566993356545766,0.6069687149857962 +34,0.7631666666666667,0.7791498655913979,0.6553838003476461,0.6052348261238426 +35,0.7626666666666667,0.7785618279569892,0.6523663277626037,0.6037159392269709 +36,0.7653333333333333,0.780325940860215,0.6526273953914642,0.6023694485105494 +37,0.7639791666666667,0.7810819892473119,0.6524330813884736,0.6012957778669172 +38,0.7673958333333334,0.7799059139784946,0.6500039516290029,0.6002496516191831 +39,0.764875,0.7815860215053764,0.6476710465749105,0.5992967845291219 +40,0.7661041666666667,0.782258064516129,0.6484741485118866,0.5987250285763894 +41,0.7668333333333334,0.7818380376344086,0.64447620677948,0.5979425240588444 +42,0.7662291666666666,0.7817540322580645,0.6458657967249553,0.5976644568545844 
+43,0.76675,0.7824260752688172,0.6449048202037811,0.5971484975789183 +44,0.7677916666666667,0.7825100806451613,0.6437487976551056,0.5968063899906733 +45,0.7677083333333333,0.7825940860215054,0.6432626353104909,0.5964810960395361 +46,0.7666041666666666,0.7826780913978495,0.6452286802132925,0.5963560240243071 +47,0.768875,0.7827620967741935,0.6432473388512929,0.5962550809947393 +48,0.7674583333333334,0.7826780913978495,0.6433716455300649,0.5961945524779699 +49,0.7678333333333334,0.7827620967741935,0.6440556212266286,0.596176440036425 +50,0.7677083333333333,0.7825940860215054,0.6438165396849315,0.5961719805835396 diff --git a/data/notebook2/random_qccnn_best.pt b/data/notebook2/random_qccnn_best.pt new file mode 100644 index 0000000..016e9a0 Binary files /dev/null and b/data/notebook2/random_qccnn_best.pt differ diff --git a/data/notebook2/random_qccnn_metrics.csv b/data/notebook2/random_qccnn_metrics.csv new file mode 100644 index 0000000..3050a59 --- /dev/null +++ b/data/notebook2/random_qccnn_metrics.csv @@ -0,0 +1,51 @@ +epoch,train_acc,valid_acc,train_loss,valid_loss +1,0.6377291666666667,0.7426075268817204,1.0222952149709066,0.6658745436899124 +2,0.7588541666666667,0.7799059139784946,0.6436424539883931,0.5876969707909451 +3,0.7829791666666667,0.7955309139784946,0.5888768948713938,0.5508138095178912 +4,0.7927916666666667,0.8009912634408602,0.5641531114578247,0.5473136597423143 +5,0.7997291666666667,0.8065356182795699,0.5444657148520152,0.515406842834206 +6,0.8043541666666667,0.8140120967741935,0.5316140202681223,0.5052012169873843 +7,0.80975,0.8156922043010753,0.521293937365214,0.4983856466508681 +8,0.8155208333333334,0.8162802419354839,0.5079634555180867,0.49958501611986467 +9,0.8161875,0.8159442204301075,0.5036936206022898,0.4907228536503289 +10,0.816875,0.8160282258064516,0.497530157327652,0.4895895427914076 +11,0.8210416666666667,0.8269489247311828,0.4916797243754069,0.47517218673101036 
+12,0.8210416666666667,0.8266129032258065,0.4892559293905894,0.468310099135163 +13,0.8241041666666666,0.8211525537634409,0.4825246704419454,0.4856182368852759 +14,0.8239791666666667,0.8264448924731183,0.48390908201535543,0.46895075997998636 +15,0.8245,0.8230846774193549,0.4795244421164195,0.47176380727880746 +16,0.8270416666666667,0.829133064516129,0.47462198424339297,0.4671747069205007 +17,0.8273958333333333,0.8297211021505376,0.4733834793567657,0.4658442559421703 +18,0.828125,0.8337533602150538,0.47195579528808596,0.46035799896845253 +19,0.8294583333333333,0.829133064516129,0.4668528276284536,0.46208705472689804 +20,0.830625,0.8303931451612904,0.46666145197550457,0.4616392642580053 +21,0.8320416666666667,0.8280409946236559,0.46406093080838523,0.4690516353935324 +22,0.8315833333333333,0.8293010752688172,0.4628266294002533,0.4590829178210228 +23,0.8315416666666666,0.8298051075268817,0.46202420779069264,0.4582436110383721 +24,0.8328541666666667,0.8282090053763441,0.4594616918563843,0.46520241454083433 +25,0.8321458333333334,0.8313172043010753,0.45899707794189454,0.4611524093535639 +26,0.835,0.8311491935483871,0.4561348853111267,0.460902286793596 +27,0.8342708333333333,0.8335013440860215,0.45571692689259846,0.4524733103731627 +28,0.83475,0.831989247311828,0.45314634958902994,0.4535403392648184 +29,0.8353958333333333,0.8333333333333334,0.4508692183494568,0.4509869323622796 +30,0.8358958333333333,0.8314012096774194,0.4508490650653839,0.4559444804345408 +31,0.8379791666666667,0.8347614247311828,0.44810916344324747,0.45132114586009775 +32,0.8369375,0.832997311827957,0.4488534228801727,0.4531173186917459 +33,0.8375208333333334,0.8366935483870968,0.4460577855904897,0.4518213977095901 +34,0.8372708333333333,0.8314852150537635,0.4453533016045888,0.4571616124081355 +35,0.8381458333333334,0.8351814516129032,0.4446527805328369,0.4513627043975297 +36,0.8392916666666667,0.8373655913978495,0.4431237142880758,0.44869983228304056 
+37,0.838375,0.8347614247311828,0.4424218458334605,0.4486082660895522 +38,0.8397916666666667,0.8355174731182796,0.4412206415732702,0.44703892129723743 +39,0.8401875,0.8365255376344086,0.4396123299598694,0.4470836206149029 +40,0.8399791666666667,0.8345934139784946,0.43971687642733254,0.4466231196157394 +41,0.8410833333333333,0.8373655913978495,0.43810279977321626,0.4461495735312021 +42,0.8408958333333333,0.8372815860215054,0.4371747978528341,0.4456534408113008 +43,0.8414583333333333,0.8363575268817204,0.43774009283383686,0.4454879440287108 +44,0.841375,0.8358534946236559,0.4363077602386475,0.4460145914426414 +45,0.8408958333333333,0.8363575268817204,0.43665687982241314,0.44518788110825325 +46,0.841875,0.8371135752688172,0.43598757874965666,0.44520101848468985 +47,0.8422083333333333,0.836945564516129,0.4349166991710663,0.44520143988311933 +48,0.8421041666666667,0.8360215053763441,0.43520630804697674,0.44510498354511874 +49,0.8416875,0.836861559139785,0.4354708949327469,0.44491737311886204 +50,0.8408125,0.836861559139785,0.43524290529886883,0.4449239748139535 diff --git a/data/notebook2/vgg_best.pt b/data/notebook2/vgg_best.pt new file mode 100644 index 0000000..787171d Binary files /dev/null and b/data/notebook2/vgg_best.pt differ diff --git a/data/notebook2/vgg_metrics.csv b/data/notebook2/vgg_metrics.csv new file mode 100644 index 0000000..a0ee2e0 --- /dev/null +++ b/data/notebook2/vgg_metrics.csv @@ -0,0 +1,51 @@ +epoch,train_acc,valid_acc,train_loss,valid_loss +1,0.2535833333333333,0.41952284946236557,2.274448096593221,2.2001695325297694 +2,0.4691875,0.5073084677419355,2.116311640103658,2.051948335862929 +3,0.5270416666666666,0.5266297043010753,2.009436710357666,1.9777893276624783 +4,0.5353541666666667,0.535114247311828,1.9530798486073813,1.9356736265203005 +5,0.5935416666666666,0.6175235215053764,1.9118746633529664,1.891348486305565 +6,0.636625,0.6625504032258065,1.8629151741663614,1.847881397893352 
+7,0.686625,0.7302587365591398,1.8228577140172322,1.8057919330494379 +8,0.7415416666666667,0.7511760752688172,1.7803055089314779,1.7647134245082896 +9,0.7542291666666666,0.7582325268817204,1.7502256110509236,1.7423059594246648 +10,0.7607708333333333,0.764616935483871,1.7317150996526083,1.724781782396378 +11,0.7677291666666667,0.7719254032258065,1.7192926041285197,1.713816902970755 +12,0.7704791666666667,0.7731014784946236,1.71117463239034,1.7070557173862253 +13,0.771625,0.7736055107526881,1.7053865385055542,1.7027398668309695 +14,0.7747708333333333,0.7749495967741935,1.7004834445317587,1.6978983327906618 +15,0.7763333333333333,0.7751176075268817,1.6967166414260864,1.6960316857983988 +16,0.7790416666666666,0.7767977150537635,1.6929608987172444,1.692371742699736 +17,0.780125,0.7801579301075269,1.6910814901987712,1.6917471321680213 +18,0.7817916666666667,0.7827620967741935,1.6882368882497152,1.6861866725388395 +19,0.7827291666666667,0.7815860215053764,1.6857908573150635,1.6868788478195027 +20,0.7827083333333333,0.7840221774193549,1.6843910643259685,1.6850647259784002 +21,0.7848541666666666,0.7857862903225806,1.6823169787724812,1.6813292862266622 +22,0.7876666666666666,0.7848622311827957,1.6801685171127319,1.6815461574062225 +23,0.7879791666666667,0.7848622311827957,1.6788572374979656,1.6818229216401295 +24,0.7890833333333334,0.7859543010752689,1.6776897859573365,1.6799169009731663 +25,0.7902708333333334,0.7883904569892473,1.6763870159784953,1.6777271570697907 +26,0.791,0.789986559139785,1.6751194190979004,1.6777111407249206 +27,0.7918541666666666,0.7886424731182796,1.674490571975708,1.6776156912567795 +28,0.7936666666666666,0.7905745967741935,1.6728278172810873,1.6744540147883917 +29,0.793375,0.7876344086021505,1.671903525352478,1.6776354030896259 +30,0.7945833333333333,0.7901545698924731,1.6710031661987306,1.6751063805754467 +31,0.7959166666666667,0.7932627688172043,1.6700964994430543,1.6719800926023913 
+32,0.7955833333333333,0.7933467741935484,1.6697869466145834,1.6723043623790945 +33,0.7971458333333333,0.792002688172043,1.668815224647522,1.6722341519530102 +34,0.7974375,0.7931787634408602,1.6682115157445272,1.6723163063808153 +35,0.7979791666666667,0.7947748655913979,1.6676976483662924,1.670254216399244 +36,0.7984583333333334,0.7950268817204301,1.6669070380528768,1.6704239499184392 +37,0.799,0.795866935483871,1.6664897476832072,1.6691849411174815 +38,0.7991875,0.7948588709677419,1.6660230029424032,1.6700230785595473 +39,0.800125,0.795950940860215,1.6655609397888183,1.6688298884258475 +40,0.8006458333333333,0.7956989247311828,1.6652250595092772,1.6690674840763051 +41,0.8006875,0.7962869623655914,1.664971160888672,1.6688041135828982 +42,0.8015416666666667,0.795866935483871,1.6647087678909303,1.668453808753721 +43,0.8013958333333333,0.7962029569892473,1.6643381754557292,1.6682484457569737 +44,0.8016458333333333,0.7965389784946236,1.6641663211186728,1.668392535178892 +45,0.8017708333333333,0.7957829301075269,1.6640429185231527,1.6681856429705055 +46,0.8021041666666666,0.7965389784946236,1.663849822362264,1.6680005634984663 +47,0.8024583333333334,0.7966229838709677,1.663770879427592,1.6679569816076627 +48,0.8025833333333333,0.7962029569892473,1.6636734215418498,1.6680130330465173 +49,0.8024583333333334,0.7970430107526881,1.6635978577931723,1.6679544077124646 +50,0.8025,0.7970430107526881,1.6635685895284016,1.6679527682642783