refactor(Modify): 重构 Modify.py 脚本
- 添加数据增强、优化器与学习率调度等改进 - 统一训练框架,支持多模型训练- 增加梯度裁剪、最佳模型保存等功能 - 优化代码结构,提高可维护性
This commit is contained in:
parent
b1818286f0
commit
a6c92a4031
696
Modify.ipynb
Normal file
696
Modify.ipynb
Normal file
File diff suppressed because one or more lines are too long
291
Modify.py
Normal file
291
Modify.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
# Modify.py
|
||||||
|
|
||||||
|
#%% 导入所有需要的包
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import deepquantum as dq
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from torchvision.datasets import FashionMNIST
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from multiprocessing import freeze_support
|
||||||
|
|
||||||
|
#%% 设置随机种子以保证可复现
|
||||||
|
def seed_torch(seed=1024):
|
||||||
|
random.seed(seed)
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
#%% 准确率计算函数
|
||||||
|
def calculate_score(y_true, y_preds):
|
||||||
|
preds_prob = torch.softmax(y_preds, dim=1)
|
||||||
|
preds_class = torch.argmax(preds_prob, dim=1)
|
||||||
|
correct = (preds_class == y_true).float()
|
||||||
|
return (correct.sum() / len(correct)).cpu().numpy()
|
||||||
|
|
||||||
|
#%% 训练与验证函数
|
||||||
|
def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):
|
||||||
|
model.to(device)
|
||||||
|
best_acc = 0.0
|
||||||
|
metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}
|
||||||
|
|
||||||
|
for epoch in range(1, num_epochs + 1):
|
||||||
|
# --- 训练阶段 ---
|
||||||
|
model.train()
|
||||||
|
running_loss, running_acc = 0.0, 0.0
|
||||||
|
for imgs, labels in train_loader:
|
||||||
|
imgs, labels = imgs.to(device), labels.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(imgs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||||
|
optimizer.step()
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_acc += calculate_score(labels, outputs)
|
||||||
|
|
||||||
|
train_loss = running_loss / len(train_loader)
|
||||||
|
train_acc = running_acc / len(train_loader)
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
# --- 验证阶段 ---
|
||||||
|
model.eval()
|
||||||
|
val_loss, val_acc = 0.0, 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for imgs, labels in valid_loader:
|
||||||
|
imgs, labels = imgs.to(device), labels.to(device)
|
||||||
|
outputs = model(imgs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
val_loss += loss.item()
|
||||||
|
val_acc += calculate_score(labels, outputs)
|
||||||
|
|
||||||
|
valid_loss = val_loss / len(valid_loader)
|
||||||
|
valid_acc = val_acc / len(valid_loader)
|
||||||
|
|
||||||
|
metrics['epoch'].append(epoch)
|
||||||
|
metrics['train_loss'].append(train_loss)
|
||||||
|
metrics['valid_loss'].append(valid_loss)
|
||||||
|
metrics['train_acc'].append(train_acc)
|
||||||
|
metrics['valid_acc'].append(valid_acc)
|
||||||
|
|
||||||
|
tqdm.write(f"[{save_path}] Epoch {epoch}/{num_epochs} "
|
||||||
|
f"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}")
|
||||||
|
|
||||||
|
if valid_acc > best_acc:
|
||||||
|
best_acc = valid_acc
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
|
||||||
|
return model, metrics
|
||||||
|
|
||||||
|
#%% 测试函数
|
||||||
|
def test_model(model, test_loader, device):
|
||||||
|
model.to(device).eval()
|
||||||
|
acc = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for imgs, labels in test_loader:
|
||||||
|
imgs, labels = imgs.to(device), labels.to(device)
|
||||||
|
outputs = model(imgs)
|
||||||
|
acc += calculate_score(labels, outputs)
|
||||||
|
acc /= len(test_loader)
|
||||||
|
print(f"Test Accuracy: {acc:.4f}")
|
||||||
|
return acc
|
||||||
|
|
||||||
|
#%% 定义量子卷积层与模型
|
||||||
|
singlegate_list = ['rx','ry','rz','s','t','p','u3']
|
||||||
|
doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']
|
||||||
|
|
||||||
|
class RandomQuantumConvolutionalLayer(nn.Module):
|
||||||
|
def __init__(self, nqubit, num_circuits, seed=1024):
|
||||||
|
super().__init__()
|
||||||
|
random.seed(seed)
|
||||||
|
self.nqubit = nqubit
|
||||||
|
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||||
|
def circuit(self, nqubit):
|
||||||
|
cir = dq.QubitCircuit(nqubit)
|
||||||
|
cir.rxlayer(encode=True); cir.barrier()
|
||||||
|
for _ in range(3):
|
||||||
|
for i in range(nqubit):
|
||||||
|
getattr(cir, random.choice(singlegate_list))(i)
|
||||||
|
c,t = random.sample(range(nqubit),2)
|
||||||
|
gate = random.choice(doublegate_list)
|
||||||
|
if gate[0] in ['r','s']:
|
||||||
|
getattr(cir, gate)([c,t])
|
||||||
|
else:
|
||||||
|
getattr(cir, gate)(c,t)
|
||||||
|
cir.barrier()
|
||||||
|
cir.observable(0)
|
||||||
|
return cir
|
||||||
|
def forward(self, x):
|
||||||
|
k,s = 2,2
|
||||||
|
x_unf = x.unfold(2,k,s).unfold(3,k,s)
|
||||||
|
w = (x.shape[-1]-k)//s + 1
|
||||||
|
x_r = x_unf.reshape(-1, self.nqubit)
|
||||||
|
exps = []
|
||||||
|
for cir in self.cirs:
|
||||||
|
cir(x_r)
|
||||||
|
exps.append(cir.expectation())
|
||||||
|
exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)
|
||||||
|
return exps
|
||||||
|
|
||||||
|
class RandomQCCNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
RandomQuantumConvolutionalLayer(4,3,seed=1024),
|
||||||
|
nn.ReLU(), nn.MaxPool2d(2,1),
|
||||||
|
nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(6*6*6,1024), nn.Dropout(0.4),
|
||||||
|
nn.Linear(1024,10)
|
||||||
|
)
|
||||||
|
def forward(self,x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x.view(x.size(0),-1)
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
class ParameterizedQuantumConvolutionalLayer(nn.Module):
|
||||||
|
def __init__(self,nqubit,num_circuits):
|
||||||
|
super().__init__()
|
||||||
|
self.nqubit = nqubit
|
||||||
|
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||||
|
def circuit(self,nqubit):
|
||||||
|
cir = dq.QubitCircuit(nqubit)
|
||||||
|
cir.rxlayer(encode=True); cir.barrier()
|
||||||
|
for _ in range(4):
|
||||||
|
cir.rylayer(); cir.cnot_ring(); cir.barrier()
|
||||||
|
cir.observable(0)
|
||||||
|
return cir
|
||||||
|
def forward(self,x):
|
||||||
|
k,s = 2,2
|
||||||
|
x_unf = x.unfold(2,k,s).unfold(3,k,s)
|
||||||
|
w = (x.shape[-1]-k)//s +1
|
||||||
|
x_r = x_unf.reshape(-1,self.nqubit)
|
||||||
|
exps = []
|
||||||
|
for cir in self.cirs:
|
||||||
|
cir(x_r); exps.append(cir.expectation())
|
||||||
|
exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)
|
||||||
|
return exps
|
||||||
|
|
||||||
|
class QCCNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
ParameterizedQuantumConvolutionalLayer(4,3),
|
||||||
|
nn.ReLU(), nn.MaxPool2d(2,1)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),
|
||||||
|
nn.Linear(128,10)
|
||||||
|
)
|
||||||
|
def forward(self,x):
|
||||||
|
x = self.conv(x); x = x.view(x.size(0),-1)
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
def vgg_block(in_c,out_c,n_convs):
|
||||||
|
layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]
|
||||||
|
for _ in range(n_convs-1):
|
||||||
|
layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]
|
||||||
|
layers.append(nn.MaxPool2d(2,2))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
VGG = nn.Sequential(
|
||||||
|
vgg_block(1,10,3),
|
||||||
|
vgg_block(10,16,3),
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(16*4*4,120), nn.Sigmoid(),
|
||||||
|
nn.Linear(120,84), nn.Sigmoid(),
|
||||||
|
nn.Linear(84,10), nn.Softmax(dim=-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
#%% 主入口
|
||||||
|
if __name__ == '__main__':
|
||||||
|
freeze_support()
|
||||||
|
|
||||||
|
# 数据增广与加载
|
||||||
|
train_transform = transforms.Compose([
|
||||||
|
transforms.Resize((18, 18)),
|
||||||
|
transforms.RandomRotation(15),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.RandomVerticalFlip(0.3),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,))
|
||||||
|
])
|
||||||
|
eval_transform = transforms.Compose([
|
||||||
|
transforms.Resize((18, 18)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,))
|
||||||
|
])
|
||||||
|
|
||||||
|
full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)
|
||||||
|
test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)
|
||||||
|
train_size = int(0.8 * len(full_train))
|
||||||
|
valid_size = len(full_train) - train_size
|
||||||
|
train_ds, valid_ds = torch.utils.data.random_split(full_train, [train_size, valid_size])
|
||||||
|
valid_ds.dataset.transform = eval_transform
|
||||||
|
|
||||||
|
batch_size = 128
|
||||||
|
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
|
||||||
|
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
|
||||||
|
|
||||||
|
# 三种模型配置
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
models = {
|
||||||
|
'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),
|
||||||
|
'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),
|
||||||
|
'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')
|
||||||
|
}
|
||||||
|
|
||||||
|
all_metrics = {}
|
||||||
|
for name, (model, lr, save_path) in models.items():
|
||||||
|
seed_torch(1024)
|
||||||
|
model = model.to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
|
||||||
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
|
||||||
|
|
||||||
|
print(f"\n=== Training {name} ===")
|
||||||
|
_, metrics = train_model(
|
||||||
|
model, criterion, optimizer, scheduler,
|
||||||
|
train_loader, valid_loader,
|
||||||
|
num_epochs=50, device=device, save_path=save_path
|
||||||
|
)
|
||||||
|
all_metrics[name] = metrics
|
||||||
|
pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)
|
||||||
|
|
||||||
|
# 测试与可视化
|
||||||
|
plt.figure(figsize=(12,5))
|
||||||
|
for i,(name,metrics) in enumerate(all_metrics.items(),1):
|
||||||
|
model, _, save_path = models[name]
|
||||||
|
best_model = model.to(device)
|
||||||
|
best_model.load_state_dict(torch.load(save_path))
|
||||||
|
print(f"\n--- Testing {name} ---")
|
||||||
|
test_model(best_model, test_loader, device)
|
||||||
|
|
||||||
|
plt.subplot(1,3,i)
|
||||||
|
plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')
|
||||||
|
plt.xlabel('Epoch'); plt.ylabel('Valid Acc')
|
||||||
|
plt.title(name); plt.legend()
|
||||||
|
|
||||||
|
plt.tight_layout(); plt.show()
|
||||||
|
|
||||||
|
# 参数量统计
|
||||||
|
def count_parameters(m):
|
||||||
|
return sum(p.numel() for p in m.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
print("\nParameter Counts:")
|
||||||
|
for name,(model,_,_) in models.items():
|
||||||
|
print(f"{name}: {count_parameters(model)}")
|
460
Origin.py
Normal file
460
Origin.py
Normal file
@ -0,0 +1,460 @@
|
|||||||
|
#%%
|
||||||
|
# 首先我们导入所有需要的包:
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import deepquantum as dq
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from tqdm import tqdm
|
||||||
|
from sklearn.metrics import roc_auc_score
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
# from torchvision.datasets import MNIST, FashionMNIST
|
||||||
|
|
||||||
|
def seed_torch(seed=1024):
|
||||||
|
"""
|
||||||
|
Set random seeds for reproducibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Random seed number to use. Default is 1024.
|
||||||
|
"""
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# Seed all GPUs with the same seed if using multi-GPU
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
seed_torch(1024)
|
||||||
|
#%%
|
||||||
|
def calculate_score(y_true, y_preds):
|
||||||
|
# 将模型预测结果转为概率分布
|
||||||
|
preds_prob = torch.softmax(y_preds, dim=1)
|
||||||
|
# 获得预测的类别(概率最高的一类)
|
||||||
|
preds_class = torch.argmax(preds_prob, dim=1)
|
||||||
|
# 计算准确率
|
||||||
|
correct = (preds_class == y_true).float()
|
||||||
|
accuracy = correct.sum() / len(correct)
|
||||||
|
return accuracy.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):
|
||||||
|
"""
|
||||||
|
训练和验证模型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): 要训练的模型。
|
||||||
|
criterion (torch.nn.Module): 损失函数。
|
||||||
|
optimizer (torch.optim.Optimizer): 优化器。
|
||||||
|
train_loader (torch.utils.data.DataLoader): 训练数据加载器。
|
||||||
|
valid_loader (torch.utils.data.DataLoader): 验证数据加载器。
|
||||||
|
num_epochs (int): 训练的epoch数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model (torch.nn.Module): 训练后的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
train_loss_list = []
|
||||||
|
valid_loss_list = []
|
||||||
|
train_acc_list = []
|
||||||
|
valid_acc_list = []
|
||||||
|
|
||||||
|
with tqdm(total=num_epochs) as pbar:
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
# 训练阶段
|
||||||
|
train_loss = 0.0
|
||||||
|
train_acc = 0.0
|
||||||
|
for images, labels in train_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(images)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
train_loss += loss.item()
|
||||||
|
train_acc += calculate_score(labels, outputs)
|
||||||
|
|
||||||
|
train_loss /= len(train_loader)
|
||||||
|
train_acc /= len(train_loader)
|
||||||
|
|
||||||
|
# 验证阶段
|
||||||
|
model.eval()
|
||||||
|
valid_loss = 0.0
|
||||||
|
valid_acc = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for images, labels in valid_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
outputs = model(images)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
valid_loss += loss.item()
|
||||||
|
valid_acc += calculate_score(labels, outputs)
|
||||||
|
|
||||||
|
valid_loss /= len(valid_loader)
|
||||||
|
valid_acc /= len(valid_loader)
|
||||||
|
|
||||||
|
pbar.set_description(f"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}")
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
|
||||||
|
train_loss_list.append(train_loss)
|
||||||
|
valid_loss_list.append(valid_loss)
|
||||||
|
train_acc_list.append(train_acc)
|
||||||
|
valid_acc_list.append(valid_acc)
|
||||||
|
|
||||||
|
metrics = {'epoch': list(range(1, num_epochs + 1)),
|
||||||
|
'train_acc': train_acc_list,
|
||||||
|
'valid_acc': valid_acc_list,
|
||||||
|
'train_loss': train_loss_list,
|
||||||
|
'valid_loss': valid_loss_list}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return model, metrics
|
||||||
|
|
||||||
|
def test_model(model, test_loader, device):
|
||||||
|
model.eval()
|
||||||
|
test_acc = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for images, labels in test_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
outputs = model(images)
|
||||||
|
test_acc += calculate_score(labels, outputs)
|
||||||
|
|
||||||
|
test_acc /= len(test_loader)
|
||||||
|
print(f'Test Acc: {test_acc:.3f}')
|
||||||
|
return test_acc
|
||||||
|
#%%
|
||||||
|
# 定义图像变换
|
||||||
|
trans1 = transforms.Compose([
|
||||||
|
transforms.Resize((18, 18)), # 调整大小为18x18
|
||||||
|
transforms.ToTensor() # 转换为张量
|
||||||
|
])
|
||||||
|
|
||||||
|
trans2 = transforms.Compose([
|
||||||
|
transforms.Resize((16, 16)), # 调整大小为16x16
|
||||||
|
transforms.ToTensor() # 转换为张量
|
||||||
|
])
|
||||||
|
train_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)
|
||||||
|
test_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)
|
||||||
|
|
||||||
|
# 定义训练集和测试集的比例
|
||||||
|
train_ratio = 0.8 # 训练集比例为80%,验证集比例为20%
|
||||||
|
valid_ratio = 0.2
|
||||||
|
total_samples = len(train_dataset)
|
||||||
|
train_size = int(train_ratio * total_samples)
|
||||||
|
valid_size = int(valid_ratio * total_samples)
|
||||||
|
|
||||||
|
# 分割训练集和测试集
|
||||||
|
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])
|
||||||
|
|
||||||
|
# 加载随机抽取的训练数据集
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
|
||||||
|
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, drop_last=True)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)
|
||||||
|
#%%
|
||||||
|
singlegate_list = ['rx', 'ry', 'rz', 's', 't', 'p', 'u3']
|
||||||
|
doublegate_list = ['rxx', 'ryy', 'rzz', 'swap', 'cnot', 'cp', 'ch', 'cu', 'ct', 'cz']
|
||||||
|
#%%
|
||||||
|
# 随机量子卷积层
|
||||||
|
class RandomQuantumConvolutionalLayer(nn.Module):
|
||||||
|
def __init__(self, nqubit, num_circuits, seed:int=1024):
|
||||||
|
super(RandomQuantumConvolutionalLayer, self).__init__()
|
||||||
|
random.seed(seed)
|
||||||
|
self.nqubit = nqubit
|
||||||
|
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||||
|
|
||||||
|
def circuit(self, nqubit):
|
||||||
|
cir = dq.QubitCircuit(nqubit)
|
||||||
|
cir.rxlayer(encode=True) # 对原论文的量子线路结构并无影响,只是做了一个数据编码的操作
|
||||||
|
cir.barrier()
|
||||||
|
for iter in range(3):
|
||||||
|
for i in range(nqubit):
|
||||||
|
singlegate = random.choice(singlegate_list)
|
||||||
|
getattr(cir, singlegate)(i)
|
||||||
|
control_bit, target_bit = random.sample(range(0, nqubit - 1), 2)
|
||||||
|
doublegate = random.choice(doublegate_list)
|
||||||
|
if doublegate[0] in ['r', 's']:
|
||||||
|
getattr(cir, doublegate)([control_bit, target_bit])
|
||||||
|
else:
|
||||||
|
getattr(cir, doublegate)(control_bit, target_bit)
|
||||||
|
cir.barrier()
|
||||||
|
|
||||||
|
cir.observable(0)
|
||||||
|
return cir
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
kernel_size, stride = 2, 2
|
||||||
|
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
|
||||||
|
x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
|
||||||
|
w = int((x.shape[-1] - kernel_size) / stride + 1)
|
||||||
|
x_reshape = x_unflod.reshape(-1, self.nqubit)
|
||||||
|
|
||||||
|
exps = []
|
||||||
|
for cir in self.cirs: # out_channels
|
||||||
|
cir(x_reshape)
|
||||||
|
exp = cir.expectation()
|
||||||
|
exps.append(exp)
|
||||||
|
|
||||||
|
exps = torch.stack(exps, dim=1)
|
||||||
|
exps = exps.reshape(x.shape[0], 3, w, w)
|
||||||
|
return exps
|
||||||
|
#%%
|
||||||
|
net = RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024)
|
||||||
|
net.cirs[0].draw()
|
||||||
|
#%%
|
||||||
|
# 基于随机量子卷积层的混合模型
|
||||||
|
class RandomQCCNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(RandomQCCNN, self).__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024), # num_circuits=3代表我们在quanv1层只用了3个量子卷积核
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=1),
|
||||||
|
nn.Conv2d(3, 6, kernel_size=2, stride=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=1)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(6 * 6 * 6, 1024),
|
||||||
|
nn.Dropout(0.4),
|
||||||
|
nn.Linear(1024, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x.reshape(x.size(0), -1)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
#%%
|
||||||
|
num_epochs = 300
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(device)
|
||||||
|
seed_torch(1024) # 重新设置随机种子
|
||||||
|
model = RandomQCCNN()
|
||||||
|
model.to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001) # 添加正则化项
|
||||||
|
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||||
|
torch.save(optim_model.state_dict(), './data/notebook1/random_qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||||
|
pd.DataFrame(metrics).to_csv('./data/notebook1/random_qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||||
|
#%%
|
||||||
|
state_dict = torch.load('./data/notebook1/random_qccnn_weights.pt', map_location=device)
|
||||||
|
random_qccnn_model = RandomQCCNN()
|
||||||
|
random_qccnn_model.load_state_dict(state_dict)
|
||||||
|
random_qccnn_model.to(device)
|
||||||
|
|
||||||
|
test_acc = test_model(random_qccnn_model, test_loader, device)
|
||||||
|
#%%
|
||||||
|
data = pd.read_csv('./data/notebook1/random_qccnn_metrics.csv')
|
||||||
|
epoch = data['epoch']
|
||||||
|
train_loss = data['train_loss']
|
||||||
|
valid_loss = data['valid_loss']
|
||||||
|
train_acc = data['train_acc']
|
||||||
|
valid_acc = data['valid_acc']
|
||||||
|
|
||||||
|
# 创建图和Axes对象
|
||||||
|
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
||||||
|
|
||||||
|
# 绘制训练损失曲线
|
||||||
|
ax1.plot(epoch, train_loss, label='Train Loss')
|
||||||
|
ax1.plot(epoch, valid_loss, label='Valid Loss')
|
||||||
|
ax1.set_title('Training Loss Curve')
|
||||||
|
ax1.set_xlabel('Epoch')
|
||||||
|
ax1.set_ylabel('Loss')
|
||||||
|
ax1.legend()
|
||||||
|
|
||||||
|
# 绘制训练准确率曲线
|
||||||
|
ax2.plot(epoch, train_acc, label='Train Accuracy')
|
||||||
|
ax2.plot(epoch, valid_acc, label='Valid Accuracy')
|
||||||
|
ax2.set_title('Training Accuracy Curve')
|
||||||
|
ax2.set_xlabel('Epoch')
|
||||||
|
ax2.set_ylabel('Accuracy')
|
||||||
|
ax2.legend()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
#%%
|
||||||
|
class ParameterizedQuantumConvolutionalLayer(nn.Module):
|
||||||
|
def __init__(self, nqubit, num_circuits):
|
||||||
|
super().__init__()
|
||||||
|
self.nqubit = nqubit
|
||||||
|
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||||
|
|
||||||
|
def circuit(self, nqubit):
|
||||||
|
cir = dq.QubitCircuit(nqubit)
|
||||||
|
cir.rxlayer(encode=True) #对原论文的量子线路结构并无影响,只是做了一个数据编码的操作
|
||||||
|
cir.barrier()
|
||||||
|
for iter in range(4): #对应原论文中一个量子卷积线路上的深度为4,可控参数一共16个
|
||||||
|
cir.rylayer()
|
||||||
|
cir.cnot_ring()
|
||||||
|
cir.barrier()
|
||||||
|
|
||||||
|
cir.observable(0)
|
||||||
|
return cir
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
kernel_size, stride = 2, 2
|
||||||
|
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
|
||||||
|
x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
|
||||||
|
w = int((x.shape[-1] - kernel_size) / stride + 1)
|
||||||
|
x_reshape = x_unflod.reshape(-1, self.nqubit)
|
||||||
|
|
||||||
|
exps = []
|
||||||
|
for cir in self.cirs: # out_channels
|
||||||
|
cir(x_reshape)
|
||||||
|
exp = cir.expectation()
|
||||||
|
exps.append(exp)
|
||||||
|
|
||||||
|
exps = torch.stack(exps, dim=1)
|
||||||
|
exps = exps.reshape(x.shape[0], 3, w, w)
|
||||||
|
return exps
|
||||||
|
#%%
|
||||||
|
# 此处我们可视化其中一个量子卷积核的线路结构:
|
||||||
|
net = ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3)
|
||||||
|
net.cirs[0].draw()
|
||||||
|
#%%
|
||||||
|
# QCCNN整体网络架构:
|
||||||
|
class QCCNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(QCCNN, self).__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(8 * 8 * 3, 128),
|
||||||
|
nn.Dropout(0.4),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(128, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x.reshape(x.size(0), -1)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
#%%
|
||||||
|
num_epochs = 300
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
model = QCCNN()
|
||||||
|
model.to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=1e-5) # 添加正则化项
|
||||||
|
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||||
|
torch.save(optim_model.state_dict(), './data/notebook1/qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||||
|
pd.DataFrame(metrics).to_csv('./data/notebook1/qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||||
|
#%%
|
||||||
|
state_dict = torch.load('./data/notebook1/qccnn_weights.pt', map_location=device)
|
||||||
|
qccnn_model = QCCNN()
|
||||||
|
qccnn_model.load_state_dict(state_dict)
|
||||||
|
qccnn_model.to(device)
|
||||||
|
|
||||||
|
test_acc = test_model(qccnn_model, test_loader, device)
|
||||||
|
#%%
|
||||||
|
def vgg_block(in_channel,out_channel,num_convs):
|
||||||
|
layers = nn.ModuleList()
|
||||||
|
assert num_convs >= 1
|
||||||
|
layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
for _ in range(num_convs-1):
|
||||||
|
layers.append(nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
VGG = nn.Sequential(
|
||||||
|
vgg_block(1,10,3), # 14,14
|
||||||
|
vgg_block(10,16,3), # 4 * 4
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(16 * 4 * 4, 120),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
nn.Linear(120, 84),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
nn.Linear(84,10),
|
||||||
|
nn.Softmax(dim=-1)
|
||||||
|
)
|
||||||
|
#%%
|
||||||
|
num_epochs = 300
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
vgg_model = VGG
|
||||||
|
vgg_model.to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(vgg_model.parameters(), lr=1e-5) # 添加正则化项
|
||||||
|
vgg_model, metrics = train_model(vgg_model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||||
|
torch.save(vgg_model.state_dict(), './data/notebook1/vgg_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||||
|
pd.DataFrame(metrics).to_csv('./data/notebook1/vgg_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||||
|
#%%
|
||||||
|
state_dict = torch.load('./data/notebook1/vgg_weights.pt', map_location=device)
|
||||||
|
vgg_model = VGG
|
||||||
|
vgg_model.load_state_dict(state_dict)
|
||||||
|
vgg_model.to(device)
|
||||||
|
|
||||||
|
vgg_test_acc = test_model(vgg_model, test_loader, device)
|
||||||
|
#%%
|
||||||
|
vgg_data = pd.read_csv('./data/notebook1/vgg_metrics.csv')
|
||||||
|
qccnn_data = pd.read_csv('./data/notebook1/qccnn_metrics.csv')
|
||||||
|
vgg_epoch = vgg_data['epoch']
|
||||||
|
vgg_train_loss = vgg_data['train_loss']
|
||||||
|
vgg_valid_loss = vgg_data['valid_loss']
|
||||||
|
vgg_train_acc = vgg_data['train_acc']
|
||||||
|
vgg_valid_acc = vgg_data['valid_acc']
|
||||||
|
|
||||||
|
qccnn_epoch = qccnn_data['epoch']
|
||||||
|
qccnn_train_loss = qccnn_data['train_loss']
|
||||||
|
qccnn_valid_loss = qccnn_data['valid_loss']
|
||||||
|
qccnn_train_acc = qccnn_data['train_acc']
|
||||||
|
qccnn_valid_acc = qccnn_data['valid_acc']
|
||||||
|
|
||||||
|
# 创建图和Axes对象
|
||||||
|
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
||||||
|
|
||||||
|
# 绘制训练损失曲线
|
||||||
|
ax1.plot(vgg_epoch, vgg_train_loss, label='VGG Train Loss')
|
||||||
|
ax1.plot(vgg_epoch, vgg_valid_loss, label='VGG Valid Loss')
|
||||||
|
ax1.plot(qccnn_epoch, qccnn_train_loss, label='QCCNN Valid Loss')
|
||||||
|
ax1.plot(qccnn_epoch, qccnn_valid_loss, label='QCCNN Valid Loss')
|
||||||
|
ax1.set_title('Training Loss Curve')
|
||||||
|
ax1.set_xlabel('Epoch')
|
||||||
|
ax1.set_ylabel('Loss')
|
||||||
|
ax1.legend()
|
||||||
|
|
||||||
|
# 绘制训练准确率曲线
|
||||||
|
ax2.plot(vgg_epoch, vgg_train_acc, label='VGG Train Accuracy')
|
||||||
|
ax2.plot(vgg_epoch, vgg_valid_acc, label='VGG Valid Accuracy')
|
||||||
|
ax2.plot(qccnn_epoch, qccnn_train_acc, label='QCCNN Train Accuracy')
|
||||||
|
ax2.plot(qccnn_epoch, qccnn_valid_acc, label='QCCNN Valid Accuracy')
|
||||||
|
ax2.set_title('Training Accuracy Curve')
|
||||||
|
ax2.set_xlabel('Epoch')
|
||||||
|
ax2.set_ylabel('Accuracy')
|
||||||
|
ax2.legend()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
#%%
|
||||||
|
# 这里我们对比不同模型之间可训练参数量的区别
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
"""
|
||||||
|
计算模型的参数数量
|
||||||
|
"""
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
number_params_VGG = count_parameters(VGG)
|
||||||
|
number_params_QCCNN = count_parameters(QCCNN())
|
||||||
|
print(f'VGG 模型可训练参数量:{number_params_VGG}\t QCCNN模型可训练参数量:{number_params_QCCNN}')
|
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/qccnn_best.pt
Normal file
BIN
data/notebook2/qccnn_best.pt
Normal file
Binary file not shown.
51
data/notebook2/qccnn_metrics.csv
Normal file
51
data/notebook2/qccnn_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||||
|
1,0.3112916666666667,0.5076444892473119,2.1129255460103353,1.7863053237238238
|
||||||
|
2,0.46210416666666665,0.5547715053763441,1.5605127541224162,1.3499009532313193
|
||||||
|
3,0.5200833333333333,0.5892137096774194,1.298180630683899,1.1617528520604616
|
||||||
|
4,0.5641666666666667,0.6311323924731183,1.161390997727712,1.0419180438082705
|
||||||
|
5,0.6055833333333334,0.6598622311827957,1.062113331158956,0.9560088053826363
|
||||||
|
6,0.6316666666666667,0.6804435483870968,0.9944041889508565,0.8958990977656457
|
||||||
|
7,0.6513958333333333,0.6928763440860215,0.9395476338068645,0.8504296951396491
|
||||||
|
8,0.6677708333333333,0.7007728494623656,0.8985106336275737,0.8148919709267155
|
||||||
|
9,0.6809375,0.7095934139784946,0.8673048818906148,0.7897684433126962
|
||||||
|
10,0.6876666666666666,0.7185819892473119,0.8408271819750468,0.7671139176173877
|
||||||
|
11,0.697625,0.7247143817204301,0.8202435002326965,0.7502269674372929
|
||||||
|
12,0.7030625,0.7303427419354839,0.8024592914581299,0.7343912464316174
|
||||||
|
13,0.7118958333333333,0.7316868279569892,0.7869214735031128,0.7222210591839205
|
||||||
|
14,0.7163958333333333,0.7404233870967742,0.7718908907572428,0.7093277560767307
|
||||||
|
15,0.7211041666666667,0.7437836021505376,0.7589795832633972,0.6981546148177116
|
||||||
|
16,0.724875,0.748067876344086,0.7483940289815267,0.6881407947950465
|
||||||
|
17,0.7293958333333334,0.75,0.7373875479698181,0.6788085179944192
|
||||||
|
18,0.7345,0.7517641129032258,0.7301613558133443,0.6724013622089099
|
||||||
|
19,0.7350416666666667,0.7550403225806451,0.7210039290587107,0.6651292238184201
|
||||||
|
20,0.7391458333333333,0.758736559139785,0.711990216255188,0.6584682823509298
|
||||||
|
21,0.7433958333333334,0.7608366935483871,0.7047773049672444,0.6513981505106854
|
||||||
|
22,0.7442916666666667,0.7634408602150538,0.6997552577654521,0.6454382868864204
|
||||||
|
23,0.7481458333333333,0.7653729838709677,0.6937046423753103,0.6401395256160408
|
||||||
|
24,0.7498333333333334,0.7683131720430108,0.6873725473880767,0.6349253362865859
|
||||||
|
25,0.7529166666666667,0.769573252688172,0.6828897857666015,0.6308066611007977
|
||||||
|
26,0.7547291666666667,0.7708333333333334,0.6747500312328338,0.6260690948655528
|
||||||
|
27,0.7547291666666667,0.7722614247311828,0.6727884339491527,0.6230032514500362
|
||||||
|
28,0.7580208333333334,0.7736055107526881,0.6705719958146413,0.6201616948650729
|
||||||
|
29,0.757125,0.7748655913978495,0.6678872625033061,0.6165956912502166
|
||||||
|
30,0.76025,0.7759576612903226,0.6641453323364258,0.6141596103227267
|
||||||
|
31,0.7609583333333333,0.7767137096774194,0.6619020007451375,0.6114715427480718
|
||||||
|
32,0.7618333333333334,0.776377688172043,0.6576849890549977,0.609050533463878
|
||||||
|
33,0.7629583333333333,0.7783938172043011,0.6566993356545766,0.6069687149857962
|
||||||
|
34,0.7631666666666667,0.7791498655913979,0.6553838003476461,0.6052348261238426
|
||||||
|
35,0.7626666666666667,0.7785618279569892,0.6523663277626037,0.6037159392269709
|
||||||
|
36,0.7653333333333333,0.780325940860215,0.6526273953914642,0.6023694485105494
|
||||||
|
37,0.7639791666666667,0.7810819892473119,0.6524330813884736,0.6012957778669172
|
||||||
|
38,0.7673958333333334,0.7799059139784946,0.6500039516290029,0.6002496516191831
|
||||||
|
39,0.764875,0.7815860215053764,0.6476710465749105,0.5992967845291219
|
||||||
|
40,0.7661041666666667,0.782258064516129,0.6484741485118866,0.5987250285763894
|
||||||
|
41,0.7668333333333334,0.7818380376344086,0.64447620677948,0.5979425240588444
|
||||||
|
42,0.7662291666666666,0.7817540322580645,0.6458657967249553,0.5976644568545844
|
||||||
|
43,0.76675,0.7824260752688172,0.6449048202037811,0.5971484975789183
|
||||||
|
44,0.7677916666666667,0.7825100806451613,0.6437487976551056,0.5968063899906733
|
||||||
|
45,0.7677083333333333,0.7825940860215054,0.6432626353104909,0.5964810960395361
|
||||||
|
46,0.7666041666666666,0.7826780913978495,0.6452286802132925,0.5963560240243071
|
||||||
|
47,0.768875,0.7827620967741935,0.6432473388512929,0.5962550809947393
|
||||||
|
48,0.7674583333333334,0.7826780913978495,0.6433716455300649,0.5961945524779699
|
||||||
|
49,0.7678333333333334,0.7827620967741935,0.6440556212266286,0.596176440036425
|
||||||
|
50,0.7677083333333333,0.7825940860215054,0.6438165396849315,0.5961719805835396
|
|
BIN
data/notebook2/random_qccnn_best.pt
Normal file
BIN
data/notebook2/random_qccnn_best.pt
Normal file
Binary file not shown.
51
data/notebook2/random_qccnn_metrics.csv
Normal file
51
data/notebook2/random_qccnn_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||||
|
1,0.6377291666666667,0.7426075268817204,1.0222952149709066,0.6658745436899124
|
||||||
|
2,0.7588541666666667,0.7799059139784946,0.6436424539883931,0.5876969707909451
|
||||||
|
3,0.7829791666666667,0.7955309139784946,0.5888768948713938,0.5508138095178912
|
||||||
|
4,0.7927916666666667,0.8009912634408602,0.5641531114578247,0.5473136597423143
|
||||||
|
5,0.7997291666666667,0.8065356182795699,0.5444657148520152,0.515406842834206
|
||||||
|
6,0.8043541666666667,0.8140120967741935,0.5316140202681223,0.5052012169873843
|
||||||
|
7,0.80975,0.8156922043010753,0.521293937365214,0.4983856466508681
|
||||||
|
8,0.8155208333333334,0.8162802419354839,0.5079634555180867,0.49958501611986467
|
||||||
|
9,0.8161875,0.8159442204301075,0.5036936206022898,0.4907228536503289
|
||||||
|
10,0.816875,0.8160282258064516,0.497530157327652,0.4895895427914076
|
||||||
|
11,0.8210416666666667,0.8269489247311828,0.4916797243754069,0.47517218673101036
|
||||||
|
12,0.8210416666666667,0.8266129032258065,0.4892559293905894,0.468310099135163
|
||||||
|
13,0.8241041666666666,0.8211525537634409,0.4825246704419454,0.4856182368852759
|
||||||
|
14,0.8239791666666667,0.8264448924731183,0.48390908201535543,0.46895075997998636
|
||||||
|
15,0.8245,0.8230846774193549,0.4795244421164195,0.47176380727880746
|
||||||
|
16,0.8270416666666667,0.829133064516129,0.47462198424339297,0.4671747069205007
|
||||||
|
17,0.8273958333333333,0.8297211021505376,0.4733834793567657,0.4658442559421703
|
||||||
|
18,0.828125,0.8337533602150538,0.47195579528808596,0.46035799896845253
|
||||||
|
19,0.8294583333333333,0.829133064516129,0.4668528276284536,0.46208705472689804
|
||||||
|
20,0.830625,0.8303931451612904,0.46666145197550457,0.4616392642580053
|
||||||
|
21,0.8320416666666667,0.8280409946236559,0.46406093080838523,0.4690516353935324
|
||||||
|
22,0.8315833333333333,0.8293010752688172,0.4628266294002533,0.4590829178210228
|
||||||
|
23,0.8315416666666666,0.8298051075268817,0.46202420779069264,0.4582436110383721
|
||||||
|
24,0.8328541666666667,0.8282090053763441,0.4594616918563843,0.46520241454083433
|
||||||
|
25,0.8321458333333334,0.8313172043010753,0.45899707794189454,0.4611524093535639
|
||||||
|
26,0.835,0.8311491935483871,0.4561348853111267,0.460902286793596
|
||||||
|
27,0.8342708333333333,0.8335013440860215,0.45571692689259846,0.4524733103731627
|
||||||
|
28,0.83475,0.831989247311828,0.45314634958902994,0.4535403392648184
|
||||||
|
29,0.8353958333333333,0.8333333333333334,0.4508692183494568,0.4509869323622796
|
||||||
|
30,0.8358958333333333,0.8314012096774194,0.4508490650653839,0.4559444804345408
|
||||||
|
31,0.8379791666666667,0.8347614247311828,0.44810916344324747,0.45132114586009775
|
||||||
|
32,0.8369375,0.832997311827957,0.4488534228801727,0.4531173186917459
|
||||||
|
33,0.8375208333333334,0.8366935483870968,0.4460577855904897,0.4518213977095901
|
||||||
|
34,0.8372708333333333,0.8314852150537635,0.4453533016045888,0.4571616124081355
|
||||||
|
35,0.8381458333333334,0.8351814516129032,0.4446527805328369,0.4513627043975297
|
||||||
|
36,0.8392916666666667,0.8373655913978495,0.4431237142880758,0.44869983228304056
|
||||||
|
37,0.838375,0.8347614247311828,0.4424218458334605,0.4486082660895522
|
||||||
|
38,0.8397916666666667,0.8355174731182796,0.4412206415732702,0.44703892129723743
|
||||||
|
39,0.8401875,0.8365255376344086,0.4396123299598694,0.4470836206149029
|
||||||
|
40,0.8399791666666667,0.8345934139784946,0.43971687642733254,0.4466231196157394
|
||||||
|
41,0.8410833333333333,0.8373655913978495,0.43810279977321626,0.4461495735312021
|
||||||
|
42,0.8408958333333333,0.8372815860215054,0.4371747978528341,0.4456534408113008
|
||||||
|
43,0.8414583333333333,0.8363575268817204,0.43774009283383686,0.4454879440287108
|
||||||
|
44,0.841375,0.8358534946236559,0.4363077602386475,0.4460145914426414
|
||||||
|
45,0.8408958333333333,0.8363575268817204,0.43665687982241314,0.44518788110825325
|
||||||
|
46,0.841875,0.8371135752688172,0.43598757874965666,0.44520101848468985
|
||||||
|
47,0.8422083333333333,0.836945564516129,0.4349166991710663,0.44520143988311933
|
||||||
|
48,0.8421041666666667,0.8360215053763441,0.43520630804697674,0.44510498354511874
|
||||||
|
49,0.8416875,0.836861559139785,0.4354708949327469,0.44491737311886204
|
||||||
|
50,0.8408125,0.836861559139785,0.43524290529886883,0.4449239748139535
|
|
BIN
data/notebook2/vgg_best.pt
Normal file
BIN
data/notebook2/vgg_best.pt
Normal file
Binary file not shown.
51
data/notebook2/vgg_metrics.csv
Normal file
51
data/notebook2/vgg_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||||
|
1,0.2535833333333333,0.41952284946236557,2.274448096593221,2.2001695325297694
|
||||||
|
2,0.4691875,0.5073084677419355,2.116311640103658,2.051948335862929
|
||||||
|
3,0.5270416666666666,0.5266297043010753,2.009436710357666,1.9777893276624783
|
||||||
|
4,0.5353541666666667,0.535114247311828,1.9530798486073813,1.9356736265203005
|
||||||
|
5,0.5935416666666666,0.6175235215053764,1.9118746633529664,1.891348486305565
|
||||||
|
6,0.636625,0.6625504032258065,1.8629151741663614,1.847881397893352
|
||||||
|
7,0.686625,0.7302587365591398,1.8228577140172322,1.8057919330494379
|
||||||
|
8,0.7415416666666667,0.7511760752688172,1.7803055089314779,1.7647134245082896
|
||||||
|
9,0.7542291666666666,0.7582325268817204,1.7502256110509236,1.7423059594246648
|
||||||
|
10,0.7607708333333333,0.764616935483871,1.7317150996526083,1.724781782396378
|
||||||
|
11,0.7677291666666667,0.7719254032258065,1.7192926041285197,1.713816902970755
|
||||||
|
12,0.7704791666666667,0.7731014784946236,1.71117463239034,1.7070557173862253
|
||||||
|
13,0.771625,0.7736055107526881,1.7053865385055542,1.7027398668309695
|
||||||
|
14,0.7747708333333333,0.7749495967741935,1.7004834445317587,1.6978983327906618
|
||||||
|
15,0.7763333333333333,0.7751176075268817,1.6967166414260864,1.6960316857983988
|
||||||
|
16,0.7790416666666666,0.7767977150537635,1.6929608987172444,1.692371742699736
|
||||||
|
17,0.780125,0.7801579301075269,1.6910814901987712,1.6917471321680213
|
||||||
|
18,0.7817916666666667,0.7827620967741935,1.6882368882497152,1.6861866725388395
|
||||||
|
19,0.7827291666666667,0.7815860215053764,1.6857908573150635,1.6868788478195027
|
||||||
|
20,0.7827083333333333,0.7840221774193549,1.6843910643259685,1.6850647259784002
|
||||||
|
21,0.7848541666666666,0.7857862903225806,1.6823169787724812,1.6813292862266622
|
||||||
|
22,0.7876666666666666,0.7848622311827957,1.6801685171127319,1.6815461574062225
|
||||||
|
23,0.7879791666666667,0.7848622311827957,1.6788572374979656,1.6818229216401295
|
||||||
|
24,0.7890833333333334,0.7859543010752689,1.6776897859573365,1.6799169009731663
|
||||||
|
25,0.7902708333333334,0.7883904569892473,1.6763870159784953,1.6777271570697907
|
||||||
|
26,0.791,0.789986559139785,1.6751194190979004,1.6777111407249206
|
||||||
|
27,0.7918541666666666,0.7886424731182796,1.674490571975708,1.6776156912567795
|
||||||
|
28,0.7936666666666666,0.7905745967741935,1.6728278172810873,1.6744540147883917
|
||||||
|
29,0.793375,0.7876344086021505,1.671903525352478,1.6776354030896259
|
||||||
|
30,0.7945833333333333,0.7901545698924731,1.6710031661987306,1.6751063805754467
|
||||||
|
31,0.7959166666666667,0.7932627688172043,1.6700964994430543,1.6719800926023913
|
||||||
|
32,0.7955833333333333,0.7933467741935484,1.6697869466145834,1.6723043623790945
|
||||||
|
33,0.7971458333333333,0.792002688172043,1.668815224647522,1.6722341519530102
|
||||||
|
34,0.7974375,0.7931787634408602,1.6682115157445272,1.6723163063808153
|
||||||
|
35,0.7979791666666667,0.7947748655913979,1.6676976483662924,1.670254216399244
|
||||||
|
36,0.7984583333333334,0.7950268817204301,1.6669070380528768,1.6704239499184392
|
||||||
|
37,0.799,0.795866935483871,1.6664897476832072,1.6691849411174815
|
||||||
|
38,0.7991875,0.7948588709677419,1.6660230029424032,1.6700230785595473
|
||||||
|
39,0.800125,0.795950940860215,1.6655609397888183,1.6688298884258475
|
||||||
|
40,0.8006458333333333,0.7956989247311828,1.6652250595092772,1.6690674840763051
|
||||||
|
41,0.8006875,0.7962869623655914,1.664971160888672,1.6688041135828982
|
||||||
|
42,0.8015416666666667,0.795866935483871,1.6647087678909303,1.668453808753721
|
||||||
|
43,0.8013958333333333,0.7962029569892473,1.6643381754557292,1.6682484457569737
|
||||||
|
44,0.8016458333333333,0.7965389784946236,1.6641663211186728,1.668392535178892
|
||||||
|
45,0.8017708333333333,0.7957829301075269,1.6640429185231527,1.6681856429705055
|
||||||
|
46,0.8021041666666666,0.7965389784946236,1.663849822362264,1.6680005634984663
|
||||||
|
47,0.8024583333333334,0.7966229838709677,1.663770879427592,1.6679569816076627
|
||||||
|
48,0.8025833333333333,0.7962029569892473,1.6636734215418498,1.6680130330465173
|
||||||
|
49,0.8024583333333334,0.7970430107526881,1.6635978577931723,1.6679544077124646
|
||||||
|
50,0.8025,0.7970430107526881,1.6635685895284016,1.6679527682642783
|
|
Loading…
Reference in New Issue
Block a user