Skip to content

Commit

Permalink
Added automatic model saving and loading
Browse files Browse the repository at this point in the history
Added automatic model saving and loading
  • Loading branch information
52hearts3 authored Oct 6, 2024
1 parent 01150ec commit e799fd1
Showing 1 changed file with 320 additions and 0 deletions.
320 changes: 320 additions & 0 deletions BigGAN-pic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
import torch
from torch import nn,optim,autograd
from torchvision import datasets ,transforms
from torch.utils.data import DataLoader

class condition_BatchNorm2d(nn.Module):
def __init__(self,num_features,num_classes): #num_features 通道数 num_classes 类别数
super(condition_BatchNorm2d,self).__init__()

self.num_features=num_features
self.bn=nn.BatchNorm2d(num_features,affine=False)
self.embed=nn.Embedding(num_classes,num_features*2) # 创建一个包含num_classes个向量,每个向量维度为num_features的嵌入层
self.embed.weight.data[:,:num_features].normal_(1,0.02) #将嵌入矩阵的前 num_features 列的值用均值为1、标准差为0.02的正态分布进行初始化
self.embed.weight.data[:,num_features:].zero_() #将嵌入层权重矩阵中从 num_features 列开始的所有列的值初始化为零

def forward(self,x,y):
out=self.bn(x)
gamma,beta=self.embed(y).chunk(2,1) #在dim=1上一分为二, [num_classes,num_features*2]==>[num_classes,num_features]
gamma=gamma.view(-1,self.num_features,1,1)
beta=beta.view(-1,self.num_features,1,1)
out=gamma*out+beta #[b,ch,1,1]*[b,ch,x,x]==>[b,ch,x,x] 广播
return out


class ResBlk(nn.Module):
def __init__(self,ch_in,ch_out,stride=1,num_classes=0):
super(ResBlk,self).__init__()

self.bn1 = condition_BatchNorm2d(num_features=ch_in, num_classes=num_classes)
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.bn2 = condition_BatchNorm2d(num_features=ch_out, num_classes=num_classes)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)

self.shortcut=nn.Conv2d(ch_in,ch_out,stride=stride,kernel_size=1,padding=0)

def forward(self,x,labels):
labels=labels.long().view(-1)
out = self.bn1(x, labels)
out = self.relu1(out)
out = self.conv1(out)
out = self.bn2(out, labels)
out = self.relu2(out)
out = self.conv2(out)
output=out+self.shortcut(x)
return output

#test
# test=torch.randn(32,4,128,128)
# labels = torch.randint(0, 10, (32,))
# print(labels.size())
# model=ResBlk(4,64,stride=2,num_classes=10)
# print(model(test,labels).shape)


def truncated_noise_sample(batch_size, z_dim, truncation=0.5): #截断函数
"""
生成截断的潜在向量。
参数:
- batch_size: 批量大小
- z_dim: 潜在向量的维度
- truncation: 截断阈值
返回:
- 截断的潜在向量
"""
noise=torch.randn(batch_size,z_dim)
truncated_noise = torch.clamp(noise, -truncation, truncation) #将noise限制在[-0.5,0.5]
return truncated_noise

#test
# batch_size = 16
# z_dim = 128
# truncation = 0.5
# truncated_z = truncated_noise_sample(batch_size, z_dim, truncation)
# print('截断的潜在向量 :',truncated_z.size())

class Generator(nn.Module): #在使用生成器前先调用截断函数
def __init__(self,z_dim,g_dim,image_size,num_classes):
super(Generator,self).__init__()

self.z_dim=z_dim
self.g_dim=g_dim
self.image_size=image_size
self.init_size=image_size//2

self.linear_1=nn.Sequential(
nn.Linear(z_dim , g_dim * 8 * self.init_size**2)
)

self.res_blocks=nn.Sequential(
ResBlk(g_dim * 8, g_dim* 8 ,stride=1,num_classes=num_classes),
ResBlk(g_dim * 8 ,g_dim * 4,stride=1,num_classes=num_classes),
ResBlk(g_dim * 4, g_dim * 2,stride=1,num_classes=num_classes),
ResBlk(g_dim * 2, g_dim,stride=1,num_classes=num_classes)
)

self.up_sample_1=nn.Sequential(
nn.Upsample(scale_factor=2), #上采样。[b,ch,x,x]==>[b,ch,2x,2x]
nn.Conv2d(g_dim, g_dim, kernel_size=3, stride=1, padding=1)
)
self.up_sample_2=condition_BatchNorm2d(num_classes=num_classes,num_features=g_dim)
self.relu=nn.ReLU(True)

# self.up_sample=nn.Sequential(
# nn.Upsample(scale_factor=2),
# nn.Conv2d(g_dim,g_dim,kernel_size=3,stride=1,padding=1),
# condition_BatchNorm2d(num_classes=num_classes,num_features=g_dim),
# nn.ReLU(True)
# )

self.final_layer=nn.Sequential(
nn.Conv2d(g_dim ,3,kernel_size=3,stride=1,padding=1),
nn.Tanh()
)

def forward(self,z,labels):
#print(z.size())
out=self.linear_1(z)
out=out.view(-1, self.g_dim * 8 , self.init_size, self.init_size)
for block in self.res_blocks:
out = block(out, labels)
out=self.up_sample_1(out)
out=self.up_sample_2(out,labels)
out=self.relu(out)
out=self.final_layer(out)
return out

#test
# z_dim = 100
# g_dim = 64
# image_size = 128
# batch_size=12
# truncation = 0.5
# truncated_z = truncated_noise_sample(batch_size, z_dim, truncation)
# labels = torch.randint(0, 10, (12,))
# G = Generator(z_dim, g_dim, image_size,num_classes=10)
# print(G(truncated_z,labels).size())

class Discriminator(nn.Module):
def __init__(self,d_dim,image_size):
super(Discriminator,self).__init__()

self.d_dim=d_dim
self.image_size=image_size

self.conv=nn.Sequential(
nn.Conv2d(3,d_dim,kernel_size=4,stride=2,padding=1),
nn.BatchNorm2d(d_dim),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(d_dim,d_dim*2,kernel_size=4,stride=2,padding=1),
nn.BatchNorm2d(d_dim*2),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(d_dim*2,d_dim*4,kernel_size=4,stride=2,padding=1),
nn.BatchNorm2d(d_dim*4),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(d_dim*4,d_dim*8,kernel_size=4,stride=2,padding=1),
nn.BatchNorm2d(d_dim*8),
nn.LeakyReLU(0.2,inplace=True)
)

self.down_sample=nn.Sequential(
nn.AdaptiveAvgPool2d(1), #[b,ch,x,x]==>[b,ch,1,1] 全局平均池化层
nn.BatchNorm2d(d_dim*8),
nn.LeakyReLU(0.2,inplace=True)
)

self.final_layer=nn.Sequential(
nn.Conv2d(d_dim*8,1,kernel_size=4,stride=1,padding=0),
)


self.linear=nn.Sequential(
nn.Linear(d_dim*8,1),
)

def forward(self,x):
out=self.conv(x)
#print(out.size())
out=self.down_sample(out)
#print(out.size())
out=out.view(out.size(0),-1)
out=self.linear(out)
out=out.view(-1)
return out

#test
# test=torch.randn(12,3,224,224)
# image_size = 224
# d_dim = 64
# D = Discriminator(image_size=image_size, d_dim=d_dim)
# print(D(test).size())
def gradient_penalty(D,x_real,x_fake,batch_size):
#[b,1,1,1]
t=torch.rand(batch_size,1,1,1,device=x_real.device)
#[b,3,32,32]
t=t.expand_as(x_real)

mid=t*x_real+((1-t)*x_fake)
mid.requires_grad_()
pred=D(mid)
grads=autograd.grad(outputs=pred,inputs=mid,
grad_outputs=torch.ones_like(pred),
create_graph=True,retain_graph=True,only_inputs=True)[0] #对mid求导
gradient=grads.view(grads.size(0),-1)
gp=((gradient.norm(2, dim=1) - 1) ** 2).mean() #grads.norm(2,dim=1)求l2范数
return gp


tf=transforms.Compose([
transforms.Resize((224,224)),
#transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data=datasets.ImageFolder(root=r'D:\game\pytorch\GAN\GAN实战-GD实现\BigGAN\pic',transform=tf)
loader=DataLoader(data,shuffle=True,batch_size=16)

z_dim = 128
g_dim = 32
image_size = 224
num_classes = 1
generator=Generator(z_dim=z_dim,g_dim=g_dim,image_size=image_size,num_classes=num_classes).to(device)
discriminator=Discriminator(d_dim=g_dim,image_size=image_size).to(device)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4) #加入了正则化
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)

import matplotlib.pyplot as plt


def show_generated_images(epoch, generator, labels, num_images=5):
if torch.cuda.is_available():
z = torch.randn(num_images, 128).cuda()
else:
z = torch.randn(num_images, 128)

# 确保labels的尺寸与z的批量大小一致
labels = labels[:num_images]
#labels = labels[:1].repeat(num_images)

fake_images = generator(z, labels).cpu().detach()
fake_images = (fake_images + 1) / 2 # 将图像从 [-1, 1] 转换到 [0, 1]

fig, axes = plt.subplots(1, num_images, figsize=(15, 3))
for i in range(num_images):
axes[i].imshow(fake_images[i].permute(1, 2, 0).numpy())
axes[i].axis('off')
plt.suptitle(f'Epoch {epoch + 1}')
plt.show()

# 定义保存和加载模型的函数
def save_checkpoint(generator, discriminator, optimizer_G, optimizer_D, epoch, path='checkpoint.pth'):
torch.save({
'epoch': epoch,
'generator_state_dict': generator.state_dict(),
'discriminator_state_dict': discriminator.state_dict(),
'optimizer_G_state_dict': optimizer_G.state_dict(),
'optimizer_D_state_dict': optimizer_D.state_dict(),
}, path)

import os
def load_checkpoint(generator, discriminator, optimizer_G, optimizer_D, path='checkpoint.pth'):
if os.path.isfile(path):
checkpoint = torch.load(path)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
start_epoch = checkpoint['epoch'] + 1
print(f"Checkpoint loaded, starting from epoch {start_epoch}")
return start_epoch
else:
print("No checkpoint found, starting from scratch")
return 0

# 加载检查点
start_epoch = load_checkpoint(generator, discriminator, optimizer_G, optimizer_D)

for epoch in range(start_epoch,1000):
for x,y in loader:
x_real,y_real=x.to(device),y.to(device)
batch_size=x_real.size(0)
for p in discriminator.parameters():
p.requires_grad = True

#训练判别器
for _ in range(3):
optimizer_D.zero_grad()
z=truncated_noise_sample(batch_size=batch_size,z_dim=z_dim,truncation=0.5).to(device)
fake_x=generator(z,y_real)

real_loss=-torch.mean(discriminator(x_real))
fake_loss=torch.mean(discriminator(fake_x.detach())) #这里不对fake_x进行反向传播

gp=gradient_penalty(discriminator,x_real=x_real,x_fake=fake_x,batch_size=batch_size)

d_loss=real_loss+fake_loss+10*gp
d_loss.backward()
optimizer_D.step()

for p in discriminator.parameters():
p.requires_grad=False
#训练生成器
optimizer_G.zero_grad()
z = truncated_noise_sample(batch_size=batch_size, z_dim=z_dim, truncation=0.5).to(device)
fake_x = generator(z, y_real)
g_loss=-torch.mean(discriminator(fake_x))
g_loss.backward()
optimizer_G.step()
#print(y_real)
print(f"Epoch [{epoch + 1}/{1000}] Loss D: {d_loss.item()}, loss G: {g_loss.item()}")
if (epoch+1)%3==0:
save_checkpoint(generator, discriminator, optimizer_G, optimizer_D, epoch)
if (epoch+1)%100==0:
show_generated_images(epoch + 1, generator, labels=y_real)

0 comments on commit e799fd1

Please sign in to comment.