refactor(Modify): refactor the Modify.py script

- Add data augmentation, an improved optimizer, and learning-rate scheduling
- Unify the training framework to support training multiple models
- Add gradient clipping, best-model checkpointing, and related features
- Restructure the code for better maintainability
fly6516 2025-06-25 03:04:30 +08:00
parent b1818286f0
commit a6c92a4031
17 changed files with 1600 additions and 0 deletions

Modify.ipynb (new file, 696 additions)

File diff suppressed because one or more lines are too long

Modify.py (new file, 291 additions)

@@ -0,0 +1,291 @@
# Modify.py
#%% Import all required packages
import os
import random
import numpy as np
import pandas as pd
import deepquantum as dq
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from tqdm import tqdm
from torch.utils.data import DataLoader
from multiprocessing import freeze_support
#%% Set random seeds for reproducibility
def seed_torch(seed=1024):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
#%% Accuracy helper
def calculate_score(y_true, y_preds):
preds_prob = torch.softmax(y_preds, dim=1)
preds_class = torch.argmax(preds_prob, dim=1)
correct = (preds_class == y_true).float()
return (correct.sum() / len(correct)).cpu().numpy()
#%% Training and validation loop
def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):
model.to(device)
best_acc = 0.0
metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}
for epoch in range(1, num_epochs + 1):
        # --- Training phase ---
model.train()
running_loss, running_acc = 0.0, 0.0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(imgs)
loss = criterion(outputs, labels)
loss.backward()
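            # Clip the global gradient norm to 1.0 to keep the updates stable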
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
running_loss += loss.item()
running_acc += calculate_score(labels, outputs)
train_loss = running_loss / len(train_loader)
train_acc = running_acc / len(train_loader)
scheduler.step()
        # --- Validation phase ---
model.eval()
val_loss, val_acc = 0.0, 0.0
with torch.no_grad():
for imgs, labels in valid_loader:
imgs, labels = imgs.to(device), labels.to(device)
outputs = model(imgs)
loss = criterion(outputs, labels)
val_loss += loss.item()
val_acc += calculate_score(labels, outputs)
valid_loss = val_loss / len(valid_loader)
valid_acc = val_acc / len(valid_loader)
metrics['epoch'].append(epoch)
metrics['train_loss'].append(train_loss)
metrics['valid_loss'].append(valid_loss)
metrics['train_acc'].append(train_acc)
metrics['valid_acc'].append(valid_acc)
tqdm.write(f"[{save_path}] Epoch {epoch}/{num_epochs} "
f"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}")
if valid_acc > best_acc:
best_acc = valid_acc
torch.save(model.state_dict(), save_path)
return model, metrics
#%% Test loop
def test_model(model, test_loader, device):
model.to(device).eval()
acc = 0.0
with torch.no_grad():
for imgs, labels in test_loader:
imgs, labels = imgs.to(device), labels.to(device)
outputs = model(imgs)
acc += calculate_score(labels, outputs)
acc /= len(test_loader)
print(f"Test Accuracy: {acc:.4f}")
return acc
#%% Quantum convolutional layers and models
singlegate_list = ['rx','ry','rz','s','t','p','u3']
doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']
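# Two-qubit gates starting with 'r' (rxx/ryy/rzz) and 'swap' take a [wire, wire] list,
# while the controlled gates (cnot, cp, ch, cu, ct, cz) take (control, target) arguments;
# the branch in circuit() below dispatches on the first letter for exactly this reason.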
class RandomQuantumConvolutionalLayer(nn.Module):
def __init__(self, nqubit, num_circuits, seed=1024):
super().__init__()
random.seed(seed)
self.nqubit = nqubit
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
def circuit(self, nqubit):
cir = dq.QubitCircuit(nqubit)
cir.rxlayer(encode=True); cir.barrier()
for _ in range(3):
for i in range(nqubit):
getattr(cir, random.choice(singlegate_list))(i)
c,t = random.sample(range(nqubit),2)
gate = random.choice(doublegate_list)
if gate[0] in ['r','s']:
getattr(cir, gate)([c,t])
else:
getattr(cir, gate)(c,t)
cir.barrier()
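        # Each circuit yields one scalar per patch: the expectation value of the
        # observable registered on qubit 0 (Pauli-Z by default in deepquantum)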
cir.observable(0)
return cir
def forward(self, x):
k,s = 2,2
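        # Non-overlapping 2x2 patches: (B, 1, 18, 18) -> (B, 1, 9, 9, 2, 2);
        # the 4 pixels of each patch are encoded onto the 4 qubits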
x_unf = x.unfold(2,k,s).unfold(3,k,s)
w = (x.shape[-1]-k)//s + 1
x_r = x_unf.reshape(-1, self.nqubit)
exps = []
for cir in self.cirs:
cir(x_r)
exps.append(cir.expectation())
exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)
return exps
class RandomQCCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
RandomQuantumConvolutionalLayer(4,3,seed=1024),
nn.ReLU(), nn.MaxPool2d(2,1),
nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)
)
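        # Feature-map flow: 18x18 -> quanv(k=2,s=2) -> 9x9 -> pool(2,1) -> 8x8
        # -> conv(2,1) -> 7x7 -> pool(2,1) -> 6x6 with 6 channels, hence 6*6*6 inputs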
self.fc = nn.Sequential(
nn.Linear(6*6*6,1024), nn.Dropout(0.4),
nn.Linear(1024,10)
)
def forward(self,x):
x = self.conv(x)
x = x.view(x.size(0),-1)
return self.fc(x)
class ParameterizedQuantumConvolutionalLayer(nn.Module):
def __init__(self,nqubit,num_circuits):
super().__init__()
self.nqubit = nqubit
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
def circuit(self,nqubit):
cir = dq.QubitCircuit(nqubit)
cir.rxlayer(encode=True); cir.barrier()
for _ in range(4):
cir.rylayer(); cir.cnot_ring(); cir.barrier()
cir.observable(0)
return cir
def forward(self,x):
k,s = 2,2
x_unf = x.unfold(2,k,s).unfold(3,k,s)
w = (x.shape[-1]-k)//s +1
x_r = x_unf.reshape(-1,self.nqubit)
exps = []
for cir in self.cirs:
cir(x_r); exps.append(cir.expectation())
exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)
return exps
class QCCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
ParameterizedQuantumConvolutionalLayer(4,3),
nn.ReLU(), nn.MaxPool2d(2,1)
)
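        # Feature-map flow: 18x18 -> quanv(k=2,s=2) -> 9x9 -> pool(2,1) -> 8x8
        # with 3 channels, hence 8*8*3 inputs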
self.fc = nn.Sequential(
nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),
nn.Linear(128,10)
)
def forward(self,x):
x = self.conv(x); x = x.view(x.size(0),-1)
return self.fc(x)
def vgg_block(in_c,out_c,n_convs):
layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]
for _ in range(n_convs-1):
layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]
layers.append(nn.MaxPool2d(2,2))
return nn.Sequential(*layers)
VGG = nn.Sequential(
vgg_block(1,10,3),
vgg_block(10,16,3),
nn.Flatten(),
nn.Linear(16*4*4,120), nn.Sigmoid(),
nn.Linear(120,84), nn.Sigmoid(),
nn.Linear(84,10), nn.Softmax(dim=-1)
)
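# Note: nn.CrossEntropyLoss applies log-softmax internally, so the trailing Softmax
# squashes the logits before the loss; it is kept here to match the original VGG definition.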
#%% Main entry point
if __name__ == '__main__':
freeze_support()
    # Data augmentation and loading
train_transform = transforms.Compose([
transforms.Resize((18, 18)),
transforms.RandomRotation(15),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(0.3),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
eval_transform = transforms.Compose([
transforms.Resize((18, 18)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
    full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)
    # Second view of the same training data with eval-time transforms, so the validation split is not augmented
    full_eval = FashionMNIST(root='./data/notebook2', train=True, transform=eval_transform, download=True)
    test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)
    train_size = int(0.8 * len(full_train))
    # Split by index: random_split shares one underlying dataset object, so overwriting
    # valid_ds.dataset.transform would silently disable augmentation for the training split too
    perm = torch.randperm(len(full_train), generator=torch.Generator().manual_seed(1024)).tolist()
    train_ds = torch.utils.data.Subset(full_train, perm[:train_size])
    valid_ds = torch.utils.data.Subset(full_eval, perm[train_size:])
batch_size = 128
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
    # Three model configurations
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {
'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),
'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),
'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')
}
all_metrics = {}
for name, (model, lr, save_path) in models.items():
seed_torch(1024)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
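        # Cosine-anneal the learning rate to near zero over the 50-epoch budget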
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
print(f"\n=== Training {name} ===")
_, metrics = train_model(
model, criterion, optimizer, scheduler,
train_loader, valid_loader,
num_epochs=50, device=device, save_path=save_path
)
all_metrics[name] = metrics
pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)
    # Testing and visualization
plt.figure(figsize=(12,5))
for i,(name,metrics) in enumerate(all_metrics.items(),1):
model, _, save_path = models[name]
best_model = model.to(device)
        best_model.load_state_dict(torch.load(save_path, map_location=device))
print(f"\n--- Testing {name} ---")
test_model(best_model, test_loader, device)
plt.subplot(1,3,i)
plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')
plt.xlabel('Epoch'); plt.ylabel('Valid Acc')
plt.title(name); plt.legend()
plt.tight_layout(); plt.show()
    # Parameter counts
def count_parameters(m):
return sum(p.numel() for p in m.parameters() if p.requires_grad)
print("\nParameter Counts:")
for name,(model,_,_) in models.items():
print(f"{name}: {count_parameters(model)}")

Origin.py (new file, 460 additions)

@@ -0,0 +1,460 @@
#%%
# First, import all required packages:
import os
import random
import numpy as np
import pandas as pd
import deepquantum as dq
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
def seed_torch(seed=1024):
"""
Set random seeds for reproducibility.
Args:
seed (int): Random seed number to use. Default is 1024.
"""
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Seed all GPUs with the same seed if using multi-GPU
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
seed_torch(1024)
#%%
def calculate_score(y_true, y_preds):
    # Convert the model outputs to a probability distribution
    preds_prob = torch.softmax(y_preds, dim=1)
    # Take the predicted class (the one with the highest probability)
    preds_class = torch.argmax(preds_prob, dim=1)
    # Compute the accuracy
    correct = (preds_class == y_true).float()
    accuracy = correct.sum() / len(correct)
    return accuracy.cpu().numpy()
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):
"""
训练和验证模型
Args:
model (torch.nn.Module): 要训练的模型
criterion (torch.nn.Module): 损失函数
optimizer (torch.optim.Optimizer): 优化器
train_loader (torch.utils.data.DataLoader): 训练数据加载器
valid_loader (torch.utils.data.DataLoader): 验证数据加载器
num_epochs (int): 训练的epoch数
Returns:
model (torch.nn.Module): 训练后的模型
"""
    train_loss_list = []
    valid_loss_list = []
    train_acc_list = []
    valid_acc_list = []
    with tqdm(total=num_epochs) as pbar:
        for epoch in range(num_epochs):
            # Training phase (switch back to train mode; validation below sets eval mode)
            model.train()
            train_loss = 0.0
            train_acc = 0.0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += calculate_score(labels, outputs)
train_loss /= len(train_loader)
train_acc /= len(train_loader)
            # Validation phase
model.eval()
valid_loss = 0.0
valid_acc = 0.0
with torch.no_grad():
for images, labels in valid_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
valid_loss += loss.item()
valid_acc += calculate_score(labels, outputs)
valid_loss /= len(valid_loader)
valid_acc /= len(valid_loader)
pbar.set_description(f"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}")
pbar.update()
train_loss_list.append(train_loss)
valid_loss_list.append(valid_loss)
train_acc_list.append(train_acc)
valid_acc_list.append(valid_acc)
metrics = {'epoch': list(range(1, num_epochs + 1)),
'train_acc': train_acc_list,
'valid_acc': valid_acc_list,
'train_loss': train_loss_list,
'valid_loss': valid_loss_list}
return model, metrics
def test_model(model, test_loader, device):
model.eval()
test_acc = 0.0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
test_acc += calculate_score(labels, outputs)
test_acc /= len(test_loader)
print(f'Test Acc: {test_acc:.3f}')
return test_acc
#%%
# Define the image transforms
trans1 = transforms.Compose([
    transforms.Resize((18, 18)),  # resize to 18x18
    transforms.ToTensor()         # convert to a tensor
])
trans2 = transforms.Compose([
    transforms.Resize((16, 16)),  # resize to 16x16
    transforms.ToTensor()         # convert to a tensor
])
train_dataset = FashionMNIST(root='./data/notebook1', train=True, transform=trans1, download=True)
test_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1, download=True)
# Train/validation split: 80% training, 20% validation
train_ratio = 0.8
valid_ratio = 0.2
total_samples = len(train_dataset)
train_size = int(train_ratio * total_samples)
valid_size = int(valid_ratio * total_samples)
# Split into training and validation sets
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])
# Data loaders over the randomly split datasets
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)
#%%
singlegate_list = ['rx', 'ry', 'rz', 's', 't', 'p', 'u3']
doublegate_list = ['rxx', 'ryy', 'rzz', 'swap', 'cnot', 'cp', 'ch', 'cu', 'ct', 'cz']
#%%
# Random quantum convolutional layer
class RandomQuantumConvolutionalLayer(nn.Module):
def __init__(self, nqubit, num_circuits, seed:int=1024):
super(RandomQuantumConvolutionalLayer, self).__init__()
random.seed(seed)
self.nqubit = nqubit
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
def circuit(self, nqubit):
cir = dq.QubitCircuit(nqubit)
        cir.rxlayer(encode=True)  # data-encoding step; does not alter the circuit structure from the original paper
        cir.barrier()
        for _ in range(3):
for i in range(nqubit):
singlegate = random.choice(singlegate_list)
getattr(cir, singlegate)(i)
control_bit, target_bit = random.sample(range(0, nqubit - 1), 2)
doublegate = random.choice(doublegate_list)
if doublegate[0] in ['r', 's']:
getattr(cir, doublegate)([control_bit, target_bit])
else:
getattr(cir, doublegate)(control_bit, target_bit)
cir.barrier()
cir.observable(0)
return cir
def forward(self, x):
kernel_size, stride = 2, 2
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
        x_unfold = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
        w = int((x.shape[-1] - kernel_size) / stride + 1)
        x_reshape = x_unfold.reshape(-1, self.nqubit)
exps = []
for cir in self.cirs: # out_channels
cir(x_reshape)
exp = cir.expectation()
exps.append(exp)
exps = torch.stack(exps, dim=1)
exps = exps.reshape(x.shape[0], 3, w, w)
return exps
#%%
net = RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024)
net.cirs[0].draw()
#%%
# Hybrid model built on the random quantum convolutional layer
class RandomQCCNN(nn.Module):
def __init__(self):
super(RandomQCCNN, self).__init__()
self.conv = nn.Sequential(
            RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024),  # num_circuits=3: the quanv1 layer uses 3 quantum kernels
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=1),
nn.Conv2d(3, 6, kernel_size=2, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=1)
)
self.fc = nn.Sequential(
nn.Linear(6 * 6 * 6, 1024),
nn.Dropout(0.4),
nn.Linear(1024, 10)
)
def forward(self, x):
x = self.conv(x)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
return x
#%%
num_epochs = 300
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
seed_torch(1024)  # reset the random seed
model = RandomQCCNN()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)  # weight decay acts as L2 regularization
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
torch.save(optim_model.state_dict(), './data/notebook1/random_qccnn_weights.pt')  # save the trained weights for later inference or testing
pd.DataFrame(metrics).to_csv('./data/notebook1/random_qccnn_metrics.csv', index=False)  # save the training history for later plotting
#%%
state_dict = torch.load('./data/notebook1/random_qccnn_weights.pt', map_location=device)
random_qccnn_model = RandomQCCNN()
random_qccnn_model.load_state_dict(state_dict)
random_qccnn_model.to(device)
test_acc = test_model(random_qccnn_model, test_loader, device)
#%%
data = pd.read_csv('./data/notebook1/random_qccnn_metrics.csv')
epoch = data['epoch']
train_loss = data['train_loss']
valid_loss = data['valid_loss']
train_acc = data['train_acc']
valid_acc = data['valid_acc']
# Create the figure and axes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Plot the loss curves
ax1.plot(epoch, train_loss, label='Train Loss')
ax1.plot(epoch, valid_loss, label='Valid Loss')
ax1.set_title('Training Loss Curve')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
# Plot the accuracy curves
ax2.plot(epoch, train_acc, label='Train Accuracy')
ax2.plot(epoch, valid_acc, label='Valid Accuracy')
ax2.set_title('Training Accuracy Curve')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
plt.show()
#%%
class ParameterizedQuantumConvolutionalLayer(nn.Module):
def __init__(self, nqubit, num_circuits):
super().__init__()
self.nqubit = nqubit
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
def circuit(self, nqubit):
cir = dq.QubitCircuit(nqubit)
        cir.rxlayer(encode=True)  # data-encoding step; does not alter the circuit structure from the original paper
        cir.barrier()
        for _ in range(4):  # depth 4 per circuit as in the original paper: 4 ry-layers x 4 qubits = 16 trainable parameters
cir.rylayer()
cir.cnot_ring()
cir.barrier()
cir.observable(0)
return cir
def forward(self, x):
kernel_size, stride = 2, 2
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
        x_unfold = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
        w = int((x.shape[-1] - kernel_size) / stride + 1)
        x_reshape = x_unfold.reshape(-1, self.nqubit)
exps = []
for cir in self.cirs: # out_channels
cir(x_reshape)
exp = cir.expectation()
exps.append(exp)
exps = torch.stack(exps, dim=1)
exps = exps.reshape(x.shape[0], 3, w, w)
return exps
#%%
# Visualize the circuit structure of one of the quantum kernels:
net = ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3)
net.cirs[0].draw()
#%%
# Overall QCCNN architecture
class QCCNN(nn.Module):
def __init__(self):
super(QCCNN, self).__init__()
self.conv = nn.Sequential(
ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=1)
)
self.fc = nn.Sequential(
nn.Linear(8 * 8 * 3, 128),
nn.Dropout(0.4),
nn.ReLU(),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.conv(x)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
return x
#%%
num_epochs = 300
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QCCNN()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
torch.save(optim_model.state_dict(), './data/notebook1/qccnn_weights.pt')  # save the trained weights for later inference or testing
pd.DataFrame(metrics).to_csv('./data/notebook1/qccnn_metrics.csv', index=False)  # save the training history for later plotting
#%%
state_dict = torch.load('./data/notebook1/qccnn_weights.pt', map_location=device)
qccnn_model = QCCNN()
qccnn_model.load_state_dict(state_dict)
qccnn_model.to(device)
test_acc = test_model(qccnn_model, test_loader, device)
#%%
def vgg_block(in_channel,out_channel,num_convs):
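    # A VGG block: num_convs (3x3 conv + ReLU) pairs followed by a 2x2 max-pool that halves H and W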
layers = nn.ModuleList()
assert num_convs >= 1
layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1))
layers.append(nn.ReLU())
for _ in range(num_convs-1):
layers.append(nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1))
layers.append(nn.ReLU())
layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
return nn.Sequential(*layers)
VGG = nn.Sequential(
    vgg_block(1,10,3),   # 18x18 -> 9x9
    vgg_block(10,16,3),  # 9x9 -> 4x4
nn.Flatten(),
nn.Linear(16 * 4 * 4, 120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.Sigmoid(),
nn.Linear(84,10),
nn.Softmax(dim=-1)
)
#%%
num_epochs = 300
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg_model = VGG
vgg_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg_model.parameters(), lr=1e-5)
vgg_model, metrics = train_model(vgg_model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
torch.save(vgg_model.state_dict(), './data/notebook1/vgg_weights.pt')  # save the trained weights for later inference or testing
pd.DataFrame(metrics).to_csv('./data/notebook1/vgg_metrics.csv', index=False)  # save the training history for later plotting
#%%
state_dict = torch.load('./data/notebook1/vgg_weights.pt', map_location=device)
vgg_model = VGG
vgg_model.load_state_dict(state_dict)
vgg_model.to(device)
vgg_test_acc = test_model(vgg_model, test_loader, device)
#%%
vgg_data = pd.read_csv('./data/notebook1/vgg_metrics.csv')
qccnn_data = pd.read_csv('./data/notebook1/qccnn_metrics.csv')
vgg_epoch = vgg_data['epoch']
vgg_train_loss = vgg_data['train_loss']
vgg_valid_loss = vgg_data['valid_loss']
vgg_train_acc = vgg_data['train_acc']
vgg_valid_acc = vgg_data['valid_acc']
qccnn_epoch = qccnn_data['epoch']
qccnn_train_loss = qccnn_data['train_loss']
qccnn_valid_loss = qccnn_data['valid_loss']
qccnn_train_acc = qccnn_data['train_acc']
qccnn_valid_acc = qccnn_data['valid_acc']
# Create the figure and axes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Plot the loss curves
ax1.plot(vgg_epoch, vgg_train_loss, label='VGG Train Loss')
ax1.plot(vgg_epoch, vgg_valid_loss, label='VGG Valid Loss')
ax1.plot(qccnn_epoch, qccnn_train_loss, label='QCCNN Train Loss')
ax1.plot(qccnn_epoch, qccnn_valid_loss, label='QCCNN Valid Loss')
ax1.set_title('Training Loss Curve')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
# Plot the accuracy curves
ax2.plot(vgg_epoch, vgg_train_acc, label='VGG Train Accuracy')
ax2.plot(vgg_epoch, vgg_valid_acc, label='VGG Valid Accuracy')
ax2.plot(qccnn_epoch, qccnn_train_acc, label='QCCNN Train Accuracy')
ax2.plot(qccnn_epoch, qccnn_valid_acc, label='QCCNN Valid Accuracy')
ax2.set_title('Training Accuracy Curve')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
plt.show()
#%%
# Compare the number of trainable parameters across the models
def count_parameters(model):
"""
计算模型的参数数量
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
number_params_VGG = count_parameters(VGG)
number_params_QCCNN = count_parameters(QCCNN())
print(f'VGG trainable parameters: {number_params_VGG}\tQCCNN trainable parameters: {number_params_QCCNN}')

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,51 @@
epoch,train_acc,valid_acc,train_loss,valid_loss
1,0.3112916666666667,0.5076444892473119,2.1129255460103353,1.7863053237238238
2,0.46210416666666665,0.5547715053763441,1.5605127541224162,1.3499009532313193
3,0.5200833333333333,0.5892137096774194,1.298180630683899,1.1617528520604616
4,0.5641666666666667,0.6311323924731183,1.161390997727712,1.0419180438082705
5,0.6055833333333334,0.6598622311827957,1.062113331158956,0.9560088053826363
6,0.6316666666666667,0.6804435483870968,0.9944041889508565,0.8958990977656457
7,0.6513958333333333,0.6928763440860215,0.9395476338068645,0.8504296951396491
8,0.6677708333333333,0.7007728494623656,0.8985106336275737,0.8148919709267155
9,0.6809375,0.7095934139784946,0.8673048818906148,0.7897684433126962
10,0.6876666666666666,0.7185819892473119,0.8408271819750468,0.7671139176173877
11,0.697625,0.7247143817204301,0.8202435002326965,0.7502269674372929
12,0.7030625,0.7303427419354839,0.8024592914581299,0.7343912464316174
13,0.7118958333333333,0.7316868279569892,0.7869214735031128,0.7222210591839205
14,0.7163958333333333,0.7404233870967742,0.7718908907572428,0.7093277560767307
15,0.7211041666666667,0.7437836021505376,0.7589795832633972,0.6981546148177116
16,0.724875,0.748067876344086,0.7483940289815267,0.6881407947950465
17,0.7293958333333334,0.75,0.7373875479698181,0.6788085179944192
18,0.7345,0.7517641129032258,0.7301613558133443,0.6724013622089099
19,0.7350416666666667,0.7550403225806451,0.7210039290587107,0.6651292238184201
20,0.7391458333333333,0.758736559139785,0.711990216255188,0.6584682823509298
21,0.7433958333333334,0.7608366935483871,0.7047773049672444,0.6513981505106854
22,0.7442916666666667,0.7634408602150538,0.6997552577654521,0.6454382868864204
23,0.7481458333333333,0.7653729838709677,0.6937046423753103,0.6401395256160408
24,0.7498333333333334,0.7683131720430108,0.6873725473880767,0.6349253362865859
25,0.7529166666666667,0.769573252688172,0.6828897857666015,0.6308066611007977
26,0.7547291666666667,0.7708333333333334,0.6747500312328338,0.6260690948655528
27,0.7547291666666667,0.7722614247311828,0.6727884339491527,0.6230032514500362
28,0.7580208333333334,0.7736055107526881,0.6705719958146413,0.6201616948650729
29,0.757125,0.7748655913978495,0.6678872625033061,0.6165956912502166
30,0.76025,0.7759576612903226,0.6641453323364258,0.6141596103227267
31,0.7609583333333333,0.7767137096774194,0.6619020007451375,0.6114715427480718
32,0.7618333333333334,0.776377688172043,0.6576849890549977,0.609050533463878
33,0.7629583333333333,0.7783938172043011,0.6566993356545766,0.6069687149857962
34,0.7631666666666667,0.7791498655913979,0.6553838003476461,0.6052348261238426
35,0.7626666666666667,0.7785618279569892,0.6523663277626037,0.6037159392269709
36,0.7653333333333333,0.780325940860215,0.6526273953914642,0.6023694485105494
37,0.7639791666666667,0.7810819892473119,0.6524330813884736,0.6012957778669172
38,0.7673958333333334,0.7799059139784946,0.6500039516290029,0.6002496516191831
39,0.764875,0.7815860215053764,0.6476710465749105,0.5992967845291219
40,0.7661041666666667,0.782258064516129,0.6484741485118866,0.5987250285763894
41,0.7668333333333334,0.7818380376344086,0.64447620677948,0.5979425240588444
42,0.7662291666666666,0.7817540322580645,0.6458657967249553,0.5976644568545844
43,0.76675,0.7824260752688172,0.6449048202037811,0.5971484975789183
44,0.7677916666666667,0.7825100806451613,0.6437487976551056,0.5968063899906733
45,0.7677083333333333,0.7825940860215054,0.6432626353104909,0.5964810960395361
46,0.7666041666666666,0.7826780913978495,0.6452286802132925,0.5963560240243071
47,0.768875,0.7827620967741935,0.6432473388512929,0.5962550809947393
48,0.7674583333333334,0.7826780913978495,0.6433716455300649,0.5961945524779699
49,0.7678333333333334,0.7827620967741935,0.6440556212266286,0.596176440036425
50,0.7677083333333333,0.7825940860215054,0.6438165396849315,0.5961719805835396

Binary file not shown.


@@ -0,0 +1,51 @@
epoch,train_acc,valid_acc,train_loss,valid_loss
1,0.6377291666666667,0.7426075268817204,1.0222952149709066,0.6658745436899124
2,0.7588541666666667,0.7799059139784946,0.6436424539883931,0.5876969707909451
3,0.7829791666666667,0.7955309139784946,0.5888768948713938,0.5508138095178912
4,0.7927916666666667,0.8009912634408602,0.5641531114578247,0.5473136597423143
5,0.7997291666666667,0.8065356182795699,0.5444657148520152,0.515406842834206
6,0.8043541666666667,0.8140120967741935,0.5316140202681223,0.5052012169873843
7,0.80975,0.8156922043010753,0.521293937365214,0.4983856466508681
8,0.8155208333333334,0.8162802419354839,0.5079634555180867,0.49958501611986467
9,0.8161875,0.8159442204301075,0.5036936206022898,0.4907228536503289
10,0.816875,0.8160282258064516,0.497530157327652,0.4895895427914076
11,0.8210416666666667,0.8269489247311828,0.4916797243754069,0.47517218673101036
12,0.8210416666666667,0.8266129032258065,0.4892559293905894,0.468310099135163
13,0.8241041666666666,0.8211525537634409,0.4825246704419454,0.4856182368852759
14,0.8239791666666667,0.8264448924731183,0.48390908201535543,0.46895075997998636
15,0.8245,0.8230846774193549,0.4795244421164195,0.47176380727880746
16,0.8270416666666667,0.829133064516129,0.47462198424339297,0.4671747069205007
17,0.8273958333333333,0.8297211021505376,0.4733834793567657,0.4658442559421703
18,0.828125,0.8337533602150538,0.47195579528808596,0.46035799896845253
19,0.8294583333333333,0.829133064516129,0.4668528276284536,0.46208705472689804
20,0.830625,0.8303931451612904,0.46666145197550457,0.4616392642580053
21,0.8320416666666667,0.8280409946236559,0.46406093080838523,0.4690516353935324
22,0.8315833333333333,0.8293010752688172,0.4628266294002533,0.4590829178210228
23,0.8315416666666666,0.8298051075268817,0.46202420779069264,0.4582436110383721
24,0.8328541666666667,0.8282090053763441,0.4594616918563843,0.46520241454083433
25,0.8321458333333334,0.8313172043010753,0.45899707794189454,0.4611524093535639
26,0.835,0.8311491935483871,0.4561348853111267,0.460902286793596
27,0.8342708333333333,0.8335013440860215,0.45571692689259846,0.4524733103731627
28,0.83475,0.831989247311828,0.45314634958902994,0.4535403392648184
29,0.8353958333333333,0.8333333333333334,0.4508692183494568,0.4509869323622796
30,0.8358958333333333,0.8314012096774194,0.4508490650653839,0.4559444804345408
31,0.8379791666666667,0.8347614247311828,0.44810916344324747,0.45132114586009775
32,0.8369375,0.832997311827957,0.4488534228801727,0.4531173186917459
33,0.8375208333333334,0.8366935483870968,0.4460577855904897,0.4518213977095901
34,0.8372708333333333,0.8314852150537635,0.4453533016045888,0.4571616124081355
35,0.8381458333333334,0.8351814516129032,0.4446527805328369,0.4513627043975297
36,0.8392916666666667,0.8373655913978495,0.4431237142880758,0.44869983228304056
37,0.838375,0.8347614247311828,0.4424218458334605,0.4486082660895522
38,0.8397916666666667,0.8355174731182796,0.4412206415732702,0.44703892129723743
39,0.8401875,0.8365255376344086,0.4396123299598694,0.4470836206149029
40,0.8399791666666667,0.8345934139784946,0.43971687642733254,0.4466231196157394
41,0.8410833333333333,0.8373655913978495,0.43810279977321626,0.4461495735312021
42,0.8408958333333333,0.8372815860215054,0.4371747978528341,0.4456534408113008
43,0.8414583333333333,0.8363575268817204,0.43774009283383686,0.4454879440287108
44,0.841375,0.8358534946236559,0.4363077602386475,0.4460145914426414
45,0.8408958333333333,0.8363575268817204,0.43665687982241314,0.44518788110825325
46,0.841875,0.8371135752688172,0.43598757874965666,0.44520101848468985
47,0.8422083333333333,0.836945564516129,0.4349166991710663,0.44520143988311933
48,0.8421041666666667,0.8360215053763441,0.43520630804697674,0.44510498354511874
49,0.8416875,0.836861559139785,0.4354708949327469,0.44491737311886204
50,0.8408125,0.836861559139785,0.43524290529886883,0.4449239748139535

data/notebook2/vgg_best.pt (new binary file)

Binary file not shown.


@@ -0,0 +1,51 @@
epoch,train_acc,valid_acc,train_loss,valid_loss
1,0.2535833333333333,0.41952284946236557,2.274448096593221,2.2001695325297694
2,0.4691875,0.5073084677419355,2.116311640103658,2.051948335862929
3,0.5270416666666666,0.5266297043010753,2.009436710357666,1.9777893276624783
4,0.5353541666666667,0.535114247311828,1.9530798486073813,1.9356736265203005
5,0.5935416666666666,0.6175235215053764,1.9118746633529664,1.891348486305565
6,0.636625,0.6625504032258065,1.8629151741663614,1.847881397893352
7,0.686625,0.7302587365591398,1.8228577140172322,1.8057919330494379
8,0.7415416666666667,0.7511760752688172,1.7803055089314779,1.7647134245082896
9,0.7542291666666666,0.7582325268817204,1.7502256110509236,1.7423059594246648
10,0.7607708333333333,0.764616935483871,1.7317150996526083,1.724781782396378
11,0.7677291666666667,0.7719254032258065,1.7192926041285197,1.713816902970755
12,0.7704791666666667,0.7731014784946236,1.71117463239034,1.7070557173862253
13,0.771625,0.7736055107526881,1.7053865385055542,1.7027398668309695
14,0.7747708333333333,0.7749495967741935,1.7004834445317587,1.6978983327906618
15,0.7763333333333333,0.7751176075268817,1.6967166414260864,1.6960316857983988
16,0.7790416666666666,0.7767977150537635,1.6929608987172444,1.692371742699736
17,0.780125,0.7801579301075269,1.6910814901987712,1.6917471321680213
18,0.7817916666666667,0.7827620967741935,1.6882368882497152,1.6861866725388395
19,0.7827291666666667,0.7815860215053764,1.6857908573150635,1.6868788478195027
20,0.7827083333333333,0.7840221774193549,1.6843910643259685,1.6850647259784002
21,0.7848541666666666,0.7857862903225806,1.6823169787724812,1.6813292862266622
22,0.7876666666666666,0.7848622311827957,1.6801685171127319,1.6815461574062225
23,0.7879791666666667,0.7848622311827957,1.6788572374979656,1.6818229216401295
24,0.7890833333333334,0.7859543010752689,1.6776897859573365,1.6799169009731663
25,0.7902708333333334,0.7883904569892473,1.6763870159784953,1.6777271570697907
26,0.791,0.789986559139785,1.6751194190979004,1.6777111407249206
27,0.7918541666666666,0.7886424731182796,1.674490571975708,1.6776156912567795
28,0.7936666666666666,0.7905745967741935,1.6728278172810873,1.6744540147883917
29,0.793375,0.7876344086021505,1.671903525352478,1.6776354030896259
30,0.7945833333333333,0.7901545698924731,1.6710031661987306,1.6751063805754467
31,0.7959166666666667,0.7932627688172043,1.6700964994430543,1.6719800926023913
32,0.7955833333333333,0.7933467741935484,1.6697869466145834,1.6723043623790945
33,0.7971458333333333,0.792002688172043,1.668815224647522,1.6722341519530102
34,0.7974375,0.7931787634408602,1.6682115157445272,1.6723163063808153
35,0.7979791666666667,0.7947748655913979,1.6676976483662924,1.670254216399244
36,0.7984583333333334,0.7950268817204301,1.6669070380528768,1.6704239499184392
37,0.799,0.795866935483871,1.6664897476832072,1.6691849411174815
38,0.7991875,0.7948588709677419,1.6660230029424032,1.6700230785595473
39,0.800125,0.795950940860215,1.6655609397888183,1.6688298884258475
40,0.8006458333333333,0.7956989247311828,1.6652250595092772,1.6690674840763051
41,0.8006875,0.7962869623655914,1.664971160888672,1.6688041135828982
42,0.8015416666666667,0.795866935483871,1.6647087678909303,1.668453808753721
43,0.8013958333333333,0.7962029569892473,1.6643381754557292,1.6682484457569737
44,0.8016458333333333,0.7965389784946236,1.6641663211186728,1.668392535178892
45,0.8017708333333333,0.7957829301075269,1.6640429185231527,1.6681856429705055
46,0.8021041666666666,0.7965389784946236,1.663849822362264,1.6680005634984663
47,0.8024583333333334,0.7966229838709677,1.663770879427592,1.6679569816076627
48,0.8025833333333333,0.7962029569892473,1.6636734215418498,1.6680130330465173
49,0.8024583333333334,0.7970430107526881,1.6635978577931723,1.6679544077124646
50,0.8025,0.7970430107526881,1.6635685895284016,1.6679527682642783