refactor(Modify): 重构 Modify.py 脚本
- 添加数据增强、优化器与学习率调度等改进 - 统一训练框架,支持多模型训练- 增加梯度裁剪、最佳模型保存等功能 - 优化代码结构,提高可维护性
This commit is contained in:
parent
b1818286f0
commit
a6c92a4031
696
Modify.ipynb
Normal file
696
Modify.ipynb
Normal file
File diff suppressed because one or more lines are too long
291
Modify.py
Normal file
291
Modify.py
Normal file
@ -0,0 +1,291 @@
|
||||
# Modify.py
|
||||
|
||||
#%% 导入所有需要的包
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import deepquantum as dq
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision.datasets import FashionMNIST
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
from multiprocessing import freeze_support
|
||||
|
||||
#%% 设置随机种子以保证可复现
|
||||
def seed_torch(seed=1024):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
#%% 准确率计算函数
|
||||
def calculate_score(y_true, y_preds):
|
||||
preds_prob = torch.softmax(y_preds, dim=1)
|
||||
preds_class = torch.argmax(preds_prob, dim=1)
|
||||
correct = (preds_class == y_true).float()
|
||||
return (correct.sum() / len(correct)).cpu().numpy()
|
||||
|
||||
#%% 训练与验证函数
|
||||
def train_model(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs, device, save_path):
|
||||
model.to(device)
|
||||
best_acc = 0.0
|
||||
metrics = {'epoch': [], 'train_acc': [], 'valid_acc': [], 'train_loss': [], 'valid_loss': []}
|
||||
|
||||
for epoch in range(1, num_epochs + 1):
|
||||
# --- 训练阶段 ---
|
||||
model.train()
|
||||
running_loss, running_acc = 0.0, 0.0
|
||||
for imgs, labels in train_loader:
|
||||
imgs, labels = imgs.to(device), labels.to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(imgs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
running_loss += loss.item()
|
||||
running_acc += calculate_score(labels, outputs)
|
||||
|
||||
train_loss = running_loss / len(train_loader)
|
||||
train_acc = running_acc / len(train_loader)
|
||||
scheduler.step()
|
||||
|
||||
# --- 验证阶段 ---
|
||||
model.eval()
|
||||
val_loss, val_acc = 0.0, 0.0
|
||||
with torch.no_grad():
|
||||
for imgs, labels in valid_loader:
|
||||
imgs, labels = imgs.to(device), labels.to(device)
|
||||
outputs = model(imgs)
|
||||
loss = criterion(outputs, labels)
|
||||
val_loss += loss.item()
|
||||
val_acc += calculate_score(labels, outputs)
|
||||
|
||||
valid_loss = val_loss / len(valid_loader)
|
||||
valid_acc = val_acc / len(valid_loader)
|
||||
|
||||
metrics['epoch'].append(epoch)
|
||||
metrics['train_loss'].append(train_loss)
|
||||
metrics['valid_loss'].append(valid_loss)
|
||||
metrics['train_acc'].append(train_acc)
|
||||
metrics['valid_acc'].append(valid_acc)
|
||||
|
||||
tqdm.write(f"[{save_path}] Epoch {epoch}/{num_epochs} "
|
||||
f"Train Acc: {train_acc:.4f} Valid Acc: {valid_acc:.4f}")
|
||||
|
||||
if valid_acc > best_acc:
|
||||
best_acc = valid_acc
|
||||
torch.save(model.state_dict(), save_path)
|
||||
|
||||
return model, metrics
|
||||
|
||||
#%% 测试函数
|
||||
def test_model(model, test_loader, device):
|
||||
model.to(device).eval()
|
||||
acc = 0.0
|
||||
with torch.no_grad():
|
||||
for imgs, labels in test_loader:
|
||||
imgs, labels = imgs.to(device), labels.to(device)
|
||||
outputs = model(imgs)
|
||||
acc += calculate_score(labels, outputs)
|
||||
acc /= len(test_loader)
|
||||
print(f"Test Accuracy: {acc:.4f}")
|
||||
return acc
|
||||
|
||||
#%% 定义量子卷积层与模型
|
||||
singlegate_list = ['rx','ry','rz','s','t','p','u3']
|
||||
doublegate_list = ['rxx','ryy','rzz','swap','cnot','cp','ch','cu','ct','cz']
|
||||
|
||||
class RandomQuantumConvolutionalLayer(nn.Module):
|
||||
def __init__(self, nqubit, num_circuits, seed=1024):
|
||||
super().__init__()
|
||||
random.seed(seed)
|
||||
self.nqubit = nqubit
|
||||
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||
def circuit(self, nqubit):
|
||||
cir = dq.QubitCircuit(nqubit)
|
||||
cir.rxlayer(encode=True); cir.barrier()
|
||||
for _ in range(3):
|
||||
for i in range(nqubit):
|
||||
getattr(cir, random.choice(singlegate_list))(i)
|
||||
c,t = random.sample(range(nqubit),2)
|
||||
gate = random.choice(doublegate_list)
|
||||
if gate[0] in ['r','s']:
|
||||
getattr(cir, gate)([c,t])
|
||||
else:
|
||||
getattr(cir, gate)(c,t)
|
||||
cir.barrier()
|
||||
cir.observable(0)
|
||||
return cir
|
||||
def forward(self, x):
|
||||
k,s = 2,2
|
||||
x_unf = x.unfold(2,k,s).unfold(3,k,s)
|
||||
w = (x.shape[-1]-k)//s + 1
|
||||
x_r = x_unf.reshape(-1, self.nqubit)
|
||||
exps = []
|
||||
for cir in self.cirs:
|
||||
cir(x_r)
|
||||
exps.append(cir.expectation())
|
||||
exps = torch.stack(exps,1).reshape(x.size(0), len(self.cirs), w, w)
|
||||
return exps
|
||||
|
||||
class RandomQCCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
RandomQuantumConvolutionalLayer(4,3,seed=1024),
|
||||
nn.ReLU(), nn.MaxPool2d(2,1),
|
||||
nn.Conv2d(3,6,2,1), nn.ReLU(), nn.MaxPool2d(2,1)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(6*6*6,1024), nn.Dropout(0.4),
|
||||
nn.Linear(1024,10)
|
||||
)
|
||||
def forward(self,x):
|
||||
x = self.conv(x)
|
||||
x = x.view(x.size(0),-1)
|
||||
return self.fc(x)
|
||||
|
||||
class ParameterizedQuantumConvolutionalLayer(nn.Module):
|
||||
def __init__(self,nqubit,num_circuits):
|
||||
super().__init__()
|
||||
self.nqubit = nqubit
|
||||
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||
def circuit(self,nqubit):
|
||||
cir = dq.QubitCircuit(nqubit)
|
||||
cir.rxlayer(encode=True); cir.barrier()
|
||||
for _ in range(4):
|
||||
cir.rylayer(); cir.cnot_ring(); cir.barrier()
|
||||
cir.observable(0)
|
||||
return cir
|
||||
def forward(self,x):
|
||||
k,s = 2,2
|
||||
x_unf = x.unfold(2,k,s).unfold(3,k,s)
|
||||
w = (x.shape[-1]-k)//s +1
|
||||
x_r = x_unf.reshape(-1,self.nqubit)
|
||||
exps = []
|
||||
for cir in self.cirs:
|
||||
cir(x_r); exps.append(cir.expectation())
|
||||
exps = torch.stack(exps,1).reshape(x.size(0),len(self.cirs),w,w)
|
||||
return exps
|
||||
|
||||
class QCCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
ParameterizedQuantumConvolutionalLayer(4,3),
|
||||
nn.ReLU(), nn.MaxPool2d(2,1)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(8*8*3,128), nn.Dropout(0.4), nn.ReLU(),
|
||||
nn.Linear(128,10)
|
||||
)
|
||||
def forward(self,x):
|
||||
x = self.conv(x); x = x.view(x.size(0),-1)
|
||||
return self.fc(x)
|
||||
|
||||
def vgg_block(in_c,out_c,n_convs):
|
||||
layers = [nn.Conv2d(in_c,out_c,3,padding=1), nn.ReLU()]
|
||||
for _ in range(n_convs-1):
|
||||
layers += [nn.Conv2d(out_c,out_c,3,padding=1), nn.ReLU()]
|
||||
layers.append(nn.MaxPool2d(2,2))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
VGG = nn.Sequential(
|
||||
vgg_block(1,10,3),
|
||||
vgg_block(10,16,3),
|
||||
nn.Flatten(),
|
||||
nn.Linear(16*4*4,120), nn.Sigmoid(),
|
||||
nn.Linear(120,84), nn.Sigmoid(),
|
||||
nn.Linear(84,10), nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
#%% 主入口
|
||||
if __name__ == '__main__':
|
||||
freeze_support()
|
||||
|
||||
# 数据增广与加载
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((18, 18)),
|
||||
transforms.RandomRotation(15),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomVerticalFlip(0.3),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,))
|
||||
])
|
||||
eval_transform = transforms.Compose([
|
||||
transforms.Resize((18, 18)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,))
|
||||
])
|
||||
|
||||
full_train = FashionMNIST(root='./data/notebook2', train=True, transform=train_transform, download=True)
|
||||
test_dataset = FashionMNIST(root='./data/notebook2', train=False, transform=eval_transform, download=True)
|
||||
train_size = int(0.8 * len(full_train))
|
||||
valid_size = len(full_train) - train_size
|
||||
train_ds, valid_ds = torch.utils.data.random_split(full_train, [train_size, valid_size])
|
||||
valid_ds.dataset.transform = eval_transform
|
||||
|
||||
batch_size = 128
|
||||
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
|
||||
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=4)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
|
||||
|
||||
# 三种模型配置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
models = {
|
||||
'random_qccnn': (RandomQCCNN(), 1e-3, './data/notebook2/random_qccnn_best.pt'),
|
||||
'qccnn': (QCCNN(), 1e-4, './data/notebook2/qccnn_best.pt'),
|
||||
'vgg': (VGG, 1e-4, './data/notebook2/vgg_best.pt')
|
||||
}
|
||||
|
||||
all_metrics = {}
|
||||
for name, (model, lr, save_path) in models.items():
|
||||
seed_torch(1024)
|
||||
model = model.to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
|
||||
|
||||
print(f"\n=== Training {name} ===")
|
||||
_, metrics = train_model(
|
||||
model, criterion, optimizer, scheduler,
|
||||
train_loader, valid_loader,
|
||||
num_epochs=50, device=device, save_path=save_path
|
||||
)
|
||||
all_metrics[name] = metrics
|
||||
pd.DataFrame(metrics).to_csv(f'./data/notebook2/{name}_metrics.csv', index=False)
|
||||
|
||||
# 测试与可视化
|
||||
plt.figure(figsize=(12,5))
|
||||
for i,(name,metrics) in enumerate(all_metrics.items(),1):
|
||||
model, _, save_path = models[name]
|
||||
best_model = model.to(device)
|
||||
best_model.load_state_dict(torch.load(save_path))
|
||||
print(f"\n--- Testing {name} ---")
|
||||
test_model(best_model, test_loader, device)
|
||||
|
||||
plt.subplot(1,3,i)
|
||||
plt.plot(metrics['epoch'], metrics['valid_acc'], label=f'{name} Val Acc')
|
||||
plt.xlabel('Epoch'); plt.ylabel('Valid Acc')
|
||||
plt.title(name); plt.legend()
|
||||
|
||||
plt.tight_layout(); plt.show()
|
||||
|
||||
# 参数量统计
|
||||
def count_parameters(m):
|
||||
return sum(p.numel() for p in m.parameters() if p.requires_grad)
|
||||
|
||||
print("\nParameter Counts:")
|
||||
for name,(model,_,_) in models.items():
|
||||
print(f"{name}: {count_parameters(model)}")
|
460
Origin.py
Normal file
460
Origin.py
Normal file
@ -0,0 +1,460 @@
|
||||
#%%
|
||||
# 首先我们导入所有需要的包:
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import deepquantum as dq
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms as transforms
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from torch.utils.data import DataLoader
|
||||
# from torchvision.datasets import MNIST, FashionMNIST
|
||||
|
||||
def seed_torch(seed=1024):
|
||||
"""
|
||||
Set random seeds for reproducibility.
|
||||
|
||||
Args:
|
||||
seed (int): Random seed number to use. Default is 1024.
|
||||
"""
|
||||
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Seed all GPUs with the same seed if using multi-GPU
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
seed_torch(1024)
|
||||
#%%
|
||||
def calculate_score(y_true, y_preds):
|
||||
# 将模型预测结果转为概率分布
|
||||
preds_prob = torch.softmax(y_preds, dim=1)
|
||||
# 获得预测的类别(概率最高的一类)
|
||||
preds_class = torch.argmax(preds_prob, dim=1)
|
||||
# 计算准确率
|
||||
correct = (preds_class == y_true).float()
|
||||
accuracy = correct.sum() / len(correct)
|
||||
return accuracy.cpu().numpy()
|
||||
|
||||
|
||||
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device):
|
||||
"""
|
||||
训练和验证模型。
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): 要训练的模型。
|
||||
criterion (torch.nn.Module): 损失函数。
|
||||
optimizer (torch.optim.Optimizer): 优化器。
|
||||
train_loader (torch.utils.data.DataLoader): 训练数据加载器。
|
||||
valid_loader (torch.utils.data.DataLoader): 验证数据加载器。
|
||||
num_epochs (int): 训练的epoch数。
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): 训练后的模型。
|
||||
"""
|
||||
|
||||
model.train()
|
||||
train_loss_list = []
|
||||
valid_loss_list = []
|
||||
train_acc_list = []
|
||||
valid_acc_list = []
|
||||
|
||||
with tqdm(total=num_epochs) as pbar:
|
||||
for epoch in range(num_epochs):
|
||||
# 训练阶段
|
||||
train_loss = 0.0
|
||||
train_acc = 0.0
|
||||
for images, labels in train_loader:
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss += loss.item()
|
||||
train_acc += calculate_score(labels, outputs)
|
||||
|
||||
train_loss /= len(train_loader)
|
||||
train_acc /= len(train_loader)
|
||||
|
||||
# 验证阶段
|
||||
model.eval()
|
||||
valid_loss = 0.0
|
||||
valid_acc = 0.0
|
||||
with torch.no_grad():
|
||||
for images, labels in valid_loader:
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, labels)
|
||||
valid_loss += loss.item()
|
||||
valid_acc += calculate_score(labels, outputs)
|
||||
|
||||
valid_loss /= len(valid_loader)
|
||||
valid_acc /= len(valid_loader)
|
||||
|
||||
pbar.set_description(f"Train loss: {train_loss:.3f} Valid Acc: {valid_acc:.3f}")
|
||||
pbar.update()
|
||||
|
||||
|
||||
train_loss_list.append(train_loss)
|
||||
valid_loss_list.append(valid_loss)
|
||||
train_acc_list.append(train_acc)
|
||||
valid_acc_list.append(valid_acc)
|
||||
|
||||
metrics = {'epoch': list(range(1, num_epochs + 1)),
|
||||
'train_acc': train_acc_list,
|
||||
'valid_acc': valid_acc_list,
|
||||
'train_loss': train_loss_list,
|
||||
'valid_loss': valid_loss_list}
|
||||
|
||||
|
||||
|
||||
return model, metrics
|
||||
|
||||
def test_model(model, test_loader, device):
|
||||
model.eval()
|
||||
test_acc = 0.0
|
||||
with torch.no_grad():
|
||||
for images, labels in test_loader:
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
outputs = model(images)
|
||||
test_acc += calculate_score(labels, outputs)
|
||||
|
||||
test_acc /= len(test_loader)
|
||||
print(f'Test Acc: {test_acc:.3f}')
|
||||
return test_acc
|
||||
#%%
|
||||
# 定义图像变换
|
||||
trans1 = transforms.Compose([
|
||||
transforms.Resize((18, 18)), # 调整大小为18x18
|
||||
transforms.ToTensor() # 转换为张量
|
||||
])
|
||||
|
||||
trans2 = transforms.Compose([
|
||||
transforms.Resize((16, 16)), # 调整大小为16x16
|
||||
transforms.ToTensor() # 转换为张量
|
||||
])
|
||||
train_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)
|
||||
test_dataset = FashionMNIST(root='./data/notebook1', train=False, transform=trans1,download=True)
|
||||
|
||||
# 定义训练集和测试集的比例
|
||||
train_ratio = 0.8 # 训练集比例为80%,验证集比例为20%
|
||||
valid_ratio = 0.2
|
||||
total_samples = len(train_dataset)
|
||||
train_size = int(train_ratio * total_samples)
|
||||
valid_size = int(valid_ratio * total_samples)
|
||||
|
||||
# 分割训练集和测试集
|
||||
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])
|
||||
|
||||
# 加载随机抽取的训练数据集
|
||||
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
|
||||
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, drop_last=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)
|
||||
#%%
|
||||
singlegate_list = ['rx', 'ry', 'rz', 's', 't', 'p', 'u3']
|
||||
doublegate_list = ['rxx', 'ryy', 'rzz', 'swap', 'cnot', 'cp', 'ch', 'cu', 'ct', 'cz']
|
||||
#%%
|
||||
# 随机量子卷积层
|
||||
class RandomQuantumConvolutionalLayer(nn.Module):
|
||||
def __init__(self, nqubit, num_circuits, seed:int=1024):
|
||||
super(RandomQuantumConvolutionalLayer, self).__init__()
|
||||
random.seed(seed)
|
||||
self.nqubit = nqubit
|
||||
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||
|
||||
def circuit(self, nqubit):
|
||||
cir = dq.QubitCircuit(nqubit)
|
||||
cir.rxlayer(encode=True) # 对原论文的量子线路结构并无影响,只是做了一个数据编码的操作
|
||||
cir.barrier()
|
||||
for iter in range(3):
|
||||
for i in range(nqubit):
|
||||
singlegate = random.choice(singlegate_list)
|
||||
getattr(cir, singlegate)(i)
|
||||
control_bit, target_bit = random.sample(range(0, nqubit - 1), 2)
|
||||
doublegate = random.choice(doublegate_list)
|
||||
if doublegate[0] in ['r', 's']:
|
||||
getattr(cir, doublegate)([control_bit, target_bit])
|
||||
else:
|
||||
getattr(cir, doublegate)(control_bit, target_bit)
|
||||
cir.barrier()
|
||||
|
||||
cir.observable(0)
|
||||
return cir
|
||||
|
||||
def forward(self, x):
|
||||
kernel_size, stride = 2, 2
|
||||
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
|
||||
x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
|
||||
w = int((x.shape[-1] - kernel_size) / stride + 1)
|
||||
x_reshape = x_unflod.reshape(-1, self.nqubit)
|
||||
|
||||
exps = []
|
||||
for cir in self.cirs: # out_channels
|
||||
cir(x_reshape)
|
||||
exp = cir.expectation()
|
||||
exps.append(exp)
|
||||
|
||||
exps = torch.stack(exps, dim=1)
|
||||
exps = exps.reshape(x.shape[0], 3, w, w)
|
||||
return exps
|
||||
#%%
|
||||
net = RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024)
|
||||
net.cirs[0].draw()
|
||||
#%%
|
||||
# 基于随机量子卷积层的混合模型
|
||||
class RandomQCCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(RandomQCCNN, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
RandomQuantumConvolutionalLayer(nqubit=4, num_circuits=3, seed=1024), # num_circuits=3代表我们在quanv1层只用了3个量子卷积核
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1),
|
||||
nn.Conv2d(3, 6, kernel_size=2, stride=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(6 * 6 * 6, 1024),
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(1024, 10)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.reshape(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
#%%
|
||||
num_epochs = 300
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(device)
|
||||
seed_torch(1024) # 重新设置随机种子
|
||||
model = RandomQCCNN()
|
||||
model.to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001) # 添加正则化项
|
||||
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||
torch.save(optim_model.state_dict(), './data/notebook1/random_qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||
pd.DataFrame(metrics).to_csv('./data/notebook1/random_qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||
#%%
|
||||
state_dict = torch.load('./data/notebook1/random_qccnn_weights.pt', map_location=device)
|
||||
random_qccnn_model = RandomQCCNN()
|
||||
random_qccnn_model.load_state_dict(state_dict)
|
||||
random_qccnn_model.to(device)
|
||||
|
||||
test_acc = test_model(random_qccnn_model, test_loader, device)
|
||||
#%%
|
||||
data = pd.read_csv('./data/notebook1/random_qccnn_metrics.csv')
|
||||
epoch = data['epoch']
|
||||
train_loss = data['train_loss']
|
||||
valid_loss = data['valid_loss']
|
||||
train_acc = data['train_acc']
|
||||
valid_acc = data['valid_acc']
|
||||
|
||||
# 创建图和Axes对象
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
||||
|
||||
# 绘制训练损失曲线
|
||||
ax1.plot(epoch, train_loss, label='Train Loss')
|
||||
ax1.plot(epoch, valid_loss, label='Valid Loss')
|
||||
ax1.set_title('Training Loss Curve')
|
||||
ax1.set_xlabel('Epoch')
|
||||
ax1.set_ylabel('Loss')
|
||||
ax1.legend()
|
||||
|
||||
# 绘制训练准确率曲线
|
||||
ax2.plot(epoch, train_acc, label='Train Accuracy')
|
||||
ax2.plot(epoch, valid_acc, label='Valid Accuracy')
|
||||
ax2.set_title('Training Accuracy Curve')
|
||||
ax2.set_xlabel('Epoch')
|
||||
ax2.set_ylabel('Accuracy')
|
||||
ax2.legend()
|
||||
|
||||
plt.show()
|
||||
#%%
|
||||
class ParameterizedQuantumConvolutionalLayer(nn.Module):
|
||||
def __init__(self, nqubit, num_circuits):
|
||||
super().__init__()
|
||||
self.nqubit = nqubit
|
||||
self.cirs = nn.ModuleList([self.circuit(nqubit) for _ in range(num_circuits)])
|
||||
|
||||
def circuit(self, nqubit):
|
||||
cir = dq.QubitCircuit(nqubit)
|
||||
cir.rxlayer(encode=True) #对原论文的量子线路结构并无影响,只是做了一个数据编码的操作
|
||||
cir.barrier()
|
||||
for iter in range(4): #对应原论文中一个量子卷积线路上的深度为4,可控参数一共16个
|
||||
cir.rylayer()
|
||||
cir.cnot_ring()
|
||||
cir.barrier()
|
||||
|
||||
cir.observable(0)
|
||||
return cir
|
||||
|
||||
def forward(self, x):
|
||||
kernel_size, stride = 2, 2
|
||||
# [64, 1, 18, 18] -> [64, 1, 9, 18, 2] -> [64, 1, 9, 9, 2, 2]
|
||||
x_unflod = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
|
||||
w = int((x.shape[-1] - kernel_size) / stride + 1)
|
||||
x_reshape = x_unflod.reshape(-1, self.nqubit)
|
||||
|
||||
exps = []
|
||||
for cir in self.cirs: # out_channels
|
||||
cir(x_reshape)
|
||||
exp = cir.expectation()
|
||||
exps.append(exp)
|
||||
|
||||
exps = torch.stack(exps, dim=1)
|
||||
exps = exps.reshape(x.shape[0], 3, w, w)
|
||||
return exps
|
||||
#%%
|
||||
# 此处我们可视化其中一个量子卷积核的线路结构:
|
||||
net = ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3)
|
||||
net.cirs[0].draw()
|
||||
#%%
|
||||
# QCCNN整体网络架构:
|
||||
class QCCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(QCCNN, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
ParameterizedQuantumConvolutionalLayer(nqubit=4, num_circuits=3),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1)
|
||||
)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(8 * 8 * 3, 128),
|
||||
nn.Dropout(0.4),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 10)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.reshape(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
#%%
|
||||
num_epochs = 300
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = QCCNN()
|
||||
model.to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-5) # 添加正则化项
|
||||
optim_model, metrics = train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||
torch.save(optim_model.state_dict(), './data/notebook1/qccnn_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||
pd.DataFrame(metrics).to_csv('./data/notebook1/qccnn_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||
#%%
|
||||
state_dict = torch.load('./data/notebook1/qccnn_weights.pt', map_location=device)
|
||||
qccnn_model = QCCNN()
|
||||
qccnn_model.load_state_dict(state_dict)
|
||||
qccnn_model.to(device)
|
||||
|
||||
test_acc = test_model(qccnn_model, test_loader, device)
|
||||
#%%
|
||||
def vgg_block(in_channel,out_channel,num_convs):
|
||||
layers = nn.ModuleList()
|
||||
assert num_convs >= 1
|
||||
layers.append(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1))
|
||||
layers.append(nn.ReLU())
|
||||
for _ in range(num_convs-1):
|
||||
layers.append(nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1))
|
||||
layers.append(nn.ReLU())
|
||||
layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
VGG = nn.Sequential(
|
||||
vgg_block(1,10,3), # 14,14
|
||||
vgg_block(10,16,3), # 4 * 4
|
||||
nn.Flatten(),
|
||||
nn.Linear(16 * 4 * 4, 120),
|
||||
nn.Sigmoid(),
|
||||
nn.Linear(120, 84),
|
||||
nn.Sigmoid(),
|
||||
nn.Linear(84,10),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
#%%
|
||||
num_epochs = 300
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
vgg_model = VGG
|
||||
vgg_model.to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(vgg_model.parameters(), lr=1e-5) # 添加正则化项
|
||||
vgg_model, metrics = train_model(vgg_model, criterion, optimizer, train_loader, valid_loader, num_epochs, device)
|
||||
torch.save(vgg_model.state_dict(), './data/notebook1/vgg_weights.pt') # 保存训练好的模型参数,用于后续的推理或测试
|
||||
pd.DataFrame(metrics).to_csv('./data/notebook1/vgg_metrics.csv', index='None') # 保存模型训练过程,用于后续图标展示
|
||||
#%%
|
||||
state_dict = torch.load('./data/notebook1/vgg_weights.pt', map_location=device)
|
||||
vgg_model = VGG
|
||||
vgg_model.load_state_dict(state_dict)
|
||||
vgg_model.to(device)
|
||||
|
||||
vgg_test_acc = test_model(vgg_model, test_loader, device)
|
||||
#%%
|
||||
vgg_data = pd.read_csv('./data/notebook1/vgg_metrics.csv')
|
||||
qccnn_data = pd.read_csv('./data/notebook1/qccnn_metrics.csv')
|
||||
vgg_epoch = vgg_data['epoch']
|
||||
vgg_train_loss = vgg_data['train_loss']
|
||||
vgg_valid_loss = vgg_data['valid_loss']
|
||||
vgg_train_acc = vgg_data['train_acc']
|
||||
vgg_valid_acc = vgg_data['valid_acc']
|
||||
|
||||
qccnn_epoch = qccnn_data['epoch']
|
||||
qccnn_train_loss = qccnn_data['train_loss']
|
||||
qccnn_valid_loss = qccnn_data['valid_loss']
|
||||
qccnn_train_acc = qccnn_data['train_acc']
|
||||
qccnn_valid_acc = qccnn_data['valid_acc']
|
||||
|
||||
# 创建图和Axes对象
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
||||
|
||||
# 绘制训练损失曲线
|
||||
ax1.plot(vgg_epoch, vgg_train_loss, label='VGG Train Loss')
|
||||
ax1.plot(vgg_epoch, vgg_valid_loss, label='VGG Valid Loss')
|
||||
ax1.plot(qccnn_epoch, qccnn_train_loss, label='QCCNN Valid Loss')
|
||||
ax1.plot(qccnn_epoch, qccnn_valid_loss, label='QCCNN Valid Loss')
|
||||
ax1.set_title('Training Loss Curve')
|
||||
ax1.set_xlabel('Epoch')
|
||||
ax1.set_ylabel('Loss')
|
||||
ax1.legend()
|
||||
|
||||
# 绘制训练准确率曲线
|
||||
ax2.plot(vgg_epoch, vgg_train_acc, label='VGG Train Accuracy')
|
||||
ax2.plot(vgg_epoch, vgg_valid_acc, label='VGG Valid Accuracy')
|
||||
ax2.plot(qccnn_epoch, qccnn_train_acc, label='QCCNN Train Accuracy')
|
||||
ax2.plot(qccnn_epoch, qccnn_valid_acc, label='QCCNN Valid Accuracy')
|
||||
ax2.set_title('Training Accuracy Curve')
|
||||
ax2.set_xlabel('Epoch')
|
||||
ax2.set_ylabel('Accuracy')
|
||||
ax2.legend()
|
||||
|
||||
plt.show()
|
||||
#%%
|
||||
# 这里我们对比不同模型之间可训练参数量的区别
|
||||
|
||||
def count_parameters(model):
|
||||
"""
|
||||
计算模型的参数数量
|
||||
"""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
number_params_VGG = count_parameters(VGG)
|
||||
number_params_QCCNN = count_parameters(QCCNN())
|
||||
print(f'VGG 模型可训练参数量:{number_params_VGG}\t QCCNN模型可训练参数量:{number_params_QCCNN}')
|
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Normal file
BIN
data/notebook2/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
data/notebook2/qccnn_best.pt
Normal file
BIN
data/notebook2/qccnn_best.pt
Normal file
Binary file not shown.
51
data/notebook2/qccnn_metrics.csv
Normal file
51
data/notebook2/qccnn_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
||||
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||
1,0.3112916666666667,0.5076444892473119,2.1129255460103353,1.7863053237238238
|
||||
2,0.46210416666666665,0.5547715053763441,1.5605127541224162,1.3499009532313193
|
||||
3,0.5200833333333333,0.5892137096774194,1.298180630683899,1.1617528520604616
|
||||
4,0.5641666666666667,0.6311323924731183,1.161390997727712,1.0419180438082705
|
||||
5,0.6055833333333334,0.6598622311827957,1.062113331158956,0.9560088053826363
|
||||
6,0.6316666666666667,0.6804435483870968,0.9944041889508565,0.8958990977656457
|
||||
7,0.6513958333333333,0.6928763440860215,0.9395476338068645,0.8504296951396491
|
||||
8,0.6677708333333333,0.7007728494623656,0.8985106336275737,0.8148919709267155
|
||||
9,0.6809375,0.7095934139784946,0.8673048818906148,0.7897684433126962
|
||||
10,0.6876666666666666,0.7185819892473119,0.8408271819750468,0.7671139176173877
|
||||
11,0.697625,0.7247143817204301,0.8202435002326965,0.7502269674372929
|
||||
12,0.7030625,0.7303427419354839,0.8024592914581299,0.7343912464316174
|
||||
13,0.7118958333333333,0.7316868279569892,0.7869214735031128,0.7222210591839205
|
||||
14,0.7163958333333333,0.7404233870967742,0.7718908907572428,0.7093277560767307
|
||||
15,0.7211041666666667,0.7437836021505376,0.7589795832633972,0.6981546148177116
|
||||
16,0.724875,0.748067876344086,0.7483940289815267,0.6881407947950465
|
||||
17,0.7293958333333334,0.75,0.7373875479698181,0.6788085179944192
|
||||
18,0.7345,0.7517641129032258,0.7301613558133443,0.6724013622089099
|
||||
19,0.7350416666666667,0.7550403225806451,0.7210039290587107,0.6651292238184201
|
||||
20,0.7391458333333333,0.758736559139785,0.711990216255188,0.6584682823509298
|
||||
21,0.7433958333333334,0.7608366935483871,0.7047773049672444,0.6513981505106854
|
||||
22,0.7442916666666667,0.7634408602150538,0.6997552577654521,0.6454382868864204
|
||||
23,0.7481458333333333,0.7653729838709677,0.6937046423753103,0.6401395256160408
|
||||
24,0.7498333333333334,0.7683131720430108,0.6873725473880767,0.6349253362865859
|
||||
25,0.7529166666666667,0.769573252688172,0.6828897857666015,0.6308066611007977
|
||||
26,0.7547291666666667,0.7708333333333334,0.6747500312328338,0.6260690948655528
|
||||
27,0.7547291666666667,0.7722614247311828,0.6727884339491527,0.6230032514500362
|
||||
28,0.7580208333333334,0.7736055107526881,0.6705719958146413,0.6201616948650729
|
||||
29,0.757125,0.7748655913978495,0.6678872625033061,0.6165956912502166
|
||||
30,0.76025,0.7759576612903226,0.6641453323364258,0.6141596103227267
|
||||
31,0.7609583333333333,0.7767137096774194,0.6619020007451375,0.6114715427480718
|
||||
32,0.7618333333333334,0.776377688172043,0.6576849890549977,0.609050533463878
|
||||
33,0.7629583333333333,0.7783938172043011,0.6566993356545766,0.6069687149857962
|
||||
34,0.7631666666666667,0.7791498655913979,0.6553838003476461,0.6052348261238426
|
||||
35,0.7626666666666667,0.7785618279569892,0.6523663277626037,0.6037159392269709
|
||||
36,0.7653333333333333,0.780325940860215,0.6526273953914642,0.6023694485105494
|
||||
37,0.7639791666666667,0.7810819892473119,0.6524330813884736,0.6012957778669172
|
||||
38,0.7673958333333334,0.7799059139784946,0.6500039516290029,0.6002496516191831
|
||||
39,0.764875,0.7815860215053764,0.6476710465749105,0.5992967845291219
|
||||
40,0.7661041666666667,0.782258064516129,0.6484741485118866,0.5987250285763894
|
||||
41,0.7668333333333334,0.7818380376344086,0.64447620677948,0.5979425240588444
|
||||
42,0.7662291666666666,0.7817540322580645,0.6458657967249553,0.5976644568545844
|
||||
43,0.76675,0.7824260752688172,0.6449048202037811,0.5971484975789183
|
||||
44,0.7677916666666667,0.7825100806451613,0.6437487976551056,0.5968063899906733
|
||||
45,0.7677083333333333,0.7825940860215054,0.6432626353104909,0.5964810960395361
|
||||
46,0.7666041666666666,0.7826780913978495,0.6452286802132925,0.5963560240243071
|
||||
47,0.768875,0.7827620967741935,0.6432473388512929,0.5962550809947393
|
||||
48,0.7674583333333334,0.7826780913978495,0.6433716455300649,0.5961945524779699
|
||||
49,0.7678333333333334,0.7827620967741935,0.6440556212266286,0.596176440036425
|
||||
50,0.7677083333333333,0.7825940860215054,0.6438165396849315,0.5961719805835396
|
|
BIN
data/notebook2/random_qccnn_best.pt
Normal file
BIN
data/notebook2/random_qccnn_best.pt
Normal file
Binary file not shown.
51
data/notebook2/random_qccnn_metrics.csv
Normal file
51
data/notebook2/random_qccnn_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
||||
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||
1,0.6377291666666667,0.7426075268817204,1.0222952149709066,0.6658745436899124
|
||||
2,0.7588541666666667,0.7799059139784946,0.6436424539883931,0.5876969707909451
|
||||
3,0.7829791666666667,0.7955309139784946,0.5888768948713938,0.5508138095178912
|
||||
4,0.7927916666666667,0.8009912634408602,0.5641531114578247,0.5473136597423143
|
||||
5,0.7997291666666667,0.8065356182795699,0.5444657148520152,0.515406842834206
|
||||
6,0.8043541666666667,0.8140120967741935,0.5316140202681223,0.5052012169873843
|
||||
7,0.80975,0.8156922043010753,0.521293937365214,0.4983856466508681
|
||||
8,0.8155208333333334,0.8162802419354839,0.5079634555180867,0.49958501611986467
|
||||
9,0.8161875,0.8159442204301075,0.5036936206022898,0.4907228536503289
|
||||
10,0.816875,0.8160282258064516,0.497530157327652,0.4895895427914076
|
||||
11,0.8210416666666667,0.8269489247311828,0.4916797243754069,0.47517218673101036
|
||||
12,0.8210416666666667,0.8266129032258065,0.4892559293905894,0.468310099135163
|
||||
13,0.8241041666666666,0.8211525537634409,0.4825246704419454,0.4856182368852759
|
||||
14,0.8239791666666667,0.8264448924731183,0.48390908201535543,0.46895075997998636
|
||||
15,0.8245,0.8230846774193549,0.4795244421164195,0.47176380727880746
|
||||
16,0.8270416666666667,0.829133064516129,0.47462198424339297,0.4671747069205007
|
||||
17,0.8273958333333333,0.8297211021505376,0.4733834793567657,0.4658442559421703
|
||||
18,0.828125,0.8337533602150538,0.47195579528808596,0.46035799896845253
|
||||
19,0.8294583333333333,0.829133064516129,0.4668528276284536,0.46208705472689804
|
||||
20,0.830625,0.8303931451612904,0.46666145197550457,0.4616392642580053
|
||||
21,0.8320416666666667,0.8280409946236559,0.46406093080838523,0.4690516353935324
|
||||
22,0.8315833333333333,0.8293010752688172,0.4628266294002533,0.4590829178210228
|
||||
23,0.8315416666666666,0.8298051075268817,0.46202420779069264,0.4582436110383721
|
||||
24,0.8328541666666667,0.8282090053763441,0.4594616918563843,0.46520241454083433
|
||||
25,0.8321458333333334,0.8313172043010753,0.45899707794189454,0.4611524093535639
|
||||
26,0.835,0.8311491935483871,0.4561348853111267,0.460902286793596
|
||||
27,0.8342708333333333,0.8335013440860215,0.45571692689259846,0.4524733103731627
|
||||
28,0.83475,0.831989247311828,0.45314634958902994,0.4535403392648184
|
||||
29,0.8353958333333333,0.8333333333333334,0.4508692183494568,0.4509869323622796
|
||||
30,0.8358958333333333,0.8314012096774194,0.4508490650653839,0.4559444804345408
|
||||
31,0.8379791666666667,0.8347614247311828,0.44810916344324747,0.45132114586009775
|
||||
32,0.8369375,0.832997311827957,0.4488534228801727,0.4531173186917459
|
||||
33,0.8375208333333334,0.8366935483870968,0.4460577855904897,0.4518213977095901
|
||||
34,0.8372708333333333,0.8314852150537635,0.4453533016045888,0.4571616124081355
|
||||
35,0.8381458333333334,0.8351814516129032,0.4446527805328369,0.4513627043975297
|
||||
36,0.8392916666666667,0.8373655913978495,0.4431237142880758,0.44869983228304056
|
||||
37,0.838375,0.8347614247311828,0.4424218458334605,0.4486082660895522
|
||||
38,0.8397916666666667,0.8355174731182796,0.4412206415732702,0.44703892129723743
|
||||
39,0.8401875,0.8365255376344086,0.4396123299598694,0.4470836206149029
|
||||
40,0.8399791666666667,0.8345934139784946,0.43971687642733254,0.4466231196157394
|
||||
41,0.8410833333333333,0.8373655913978495,0.43810279977321626,0.4461495735312021
|
||||
42,0.8408958333333333,0.8372815860215054,0.4371747978528341,0.4456534408113008
|
||||
43,0.8414583333333333,0.8363575268817204,0.43774009283383686,0.4454879440287108
|
||||
44,0.841375,0.8358534946236559,0.4363077602386475,0.4460145914426414
|
||||
45,0.8408958333333333,0.8363575268817204,0.43665687982241314,0.44518788110825325
|
||||
46,0.841875,0.8371135752688172,0.43598757874965666,0.44520101848468985
|
||||
47,0.8422083333333333,0.836945564516129,0.4349166991710663,0.44520143988311933
|
||||
48,0.8421041666666667,0.8360215053763441,0.43520630804697674,0.44510498354511874
|
||||
49,0.8416875,0.836861559139785,0.4354708949327469,0.44491737311886204
|
||||
50,0.8408125,0.836861559139785,0.43524290529886883,0.4449239748139535
|
|
BIN
data/notebook2/vgg_best.pt
Normal file
BIN
data/notebook2/vgg_best.pt
Normal file
Binary file not shown.
51
data/notebook2/vgg_metrics.csv
Normal file
51
data/notebook2/vgg_metrics.csv
Normal file
@ -0,0 +1,51 @@
|
||||
epoch,train_acc,valid_acc,train_loss,valid_loss
|
||||
1,0.2535833333333333,0.41952284946236557,2.274448096593221,2.2001695325297694
|
||||
2,0.4691875,0.5073084677419355,2.116311640103658,2.051948335862929
|
||||
3,0.5270416666666666,0.5266297043010753,2.009436710357666,1.9777893276624783
|
||||
4,0.5353541666666667,0.535114247311828,1.9530798486073813,1.9356736265203005
|
||||
5,0.5935416666666666,0.6175235215053764,1.9118746633529664,1.891348486305565
|
||||
6,0.636625,0.6625504032258065,1.8629151741663614,1.847881397893352
|
||||
7,0.686625,0.7302587365591398,1.8228577140172322,1.8057919330494379
|
||||
8,0.7415416666666667,0.7511760752688172,1.7803055089314779,1.7647134245082896
|
||||
9,0.7542291666666666,0.7582325268817204,1.7502256110509236,1.7423059594246648
|
||||
10,0.7607708333333333,0.764616935483871,1.7317150996526083,1.724781782396378
|
||||
11,0.7677291666666667,0.7719254032258065,1.7192926041285197,1.713816902970755
|
||||
12,0.7704791666666667,0.7731014784946236,1.71117463239034,1.7070557173862253
|
||||
13,0.771625,0.7736055107526881,1.7053865385055542,1.7027398668309695
|
||||
14,0.7747708333333333,0.7749495967741935,1.7004834445317587,1.6978983327906618
|
||||
15,0.7763333333333333,0.7751176075268817,1.6967166414260864,1.6960316857983988
|
||||
16,0.7790416666666666,0.7767977150537635,1.6929608987172444,1.692371742699736
|
||||
17,0.780125,0.7801579301075269,1.6910814901987712,1.6917471321680213
|
||||
18,0.7817916666666667,0.7827620967741935,1.6882368882497152,1.6861866725388395
|
||||
19,0.7827291666666667,0.7815860215053764,1.6857908573150635,1.6868788478195027
|
||||
20,0.7827083333333333,0.7840221774193549,1.6843910643259685,1.6850647259784002
|
||||
21,0.7848541666666666,0.7857862903225806,1.6823169787724812,1.6813292862266622
|
||||
22,0.7876666666666666,0.7848622311827957,1.6801685171127319,1.6815461574062225
|
||||
23,0.7879791666666667,0.7848622311827957,1.6788572374979656,1.6818229216401295
|
||||
24,0.7890833333333334,0.7859543010752689,1.6776897859573365,1.6799169009731663
|
||||
25,0.7902708333333334,0.7883904569892473,1.6763870159784953,1.6777271570697907
|
||||
26,0.791,0.789986559139785,1.6751194190979004,1.6777111407249206
|
||||
27,0.7918541666666666,0.7886424731182796,1.674490571975708,1.6776156912567795
|
||||
28,0.7936666666666666,0.7905745967741935,1.6728278172810873,1.6744540147883917
|
||||
29,0.793375,0.7876344086021505,1.671903525352478,1.6776354030896259
|
||||
30,0.7945833333333333,0.7901545698924731,1.6710031661987306,1.6751063805754467
|
||||
31,0.7959166666666667,0.7932627688172043,1.6700964994430543,1.6719800926023913
|
||||
32,0.7955833333333333,0.7933467741935484,1.6697869466145834,1.6723043623790945
|
||||
33,0.7971458333333333,0.792002688172043,1.668815224647522,1.6722341519530102
|
||||
34,0.7974375,0.7931787634408602,1.6682115157445272,1.6723163063808153
|
||||
35,0.7979791666666667,0.7947748655913979,1.6676976483662924,1.670254216399244
|
||||
36,0.7984583333333334,0.7950268817204301,1.6669070380528768,1.6704239499184392
|
||||
37,0.799,0.795866935483871,1.6664897476832072,1.6691849411174815
|
||||
38,0.7991875,0.7948588709677419,1.6660230029424032,1.6700230785595473
|
||||
39,0.800125,0.795950940860215,1.6655609397888183,1.6688298884258475
|
||||
40,0.8006458333333333,0.7956989247311828,1.6652250595092772,1.6690674840763051
|
||||
41,0.8006875,0.7962869623655914,1.664971160888672,1.6688041135828982
|
||||
42,0.8015416666666667,0.795866935483871,1.6647087678909303,1.668453808753721
|
||||
43,0.8013958333333333,0.7962029569892473,1.6643381754557292,1.6682484457569737
|
||||
44,0.8016458333333333,0.7965389784946236,1.6641663211186728,1.668392535178892
|
||||
45,0.8017708333333333,0.7957829301075269,1.6640429185231527,1.6681856429705055
|
||||
46,0.8021041666666666,0.7965389784946236,1.663849822362264,1.6680005634984663
|
||||
47,0.8024583333333334,0.7966229838709677,1.663770879427592,1.6679569816076627
|
||||
48,0.8025833333333333,0.7962029569892473,1.6636734215418498,1.6680130330465173
|
||||
49,0.8024583333333334,0.7970430107526881,1.6635978577931723,1.6679544077124646
|
||||
50,0.8025,0.7970430107526881,1.6635685895284016,1.6679527682642783
|
|
Loading…
Reference in New Issue
Block a user