- 获取链接
- X
- 电子邮件
- 其他应用
- 获取链接
- X
- 电子邮件
- 其他应用
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Define the variational autoencoder (VAE) model.
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: input_dim -> hidden_dim (ReLU) -> two heads of width latent_dim
    producing the mean and log-variance of q(z|x).
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim, squashed to (0, 1)
    by a sigmoid.

    Args:
        input_dim: Number of features in one flattened input sample.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Dimensionality of the latent code z.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) for the flattened batch ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    # Reparameterization trick
    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I).

        Drawing the noise separately keeps z differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to (-1, input_dim). The width is read from the
        first layer instead of being hard-coded to 784, so models built with
        any ``input_dim`` work.

        Returns:
            (reconstruction, mu, logvar)
        """
        mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
# Define the loss function: reconstruction term + KL divergence.
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss for a batch, summed over all elements.

    Args:
        recon_x: Decoder output with values in [0, 1]. Its last dimension is
            taken as the flattened per-sample size.
        x: Target batch with values in [0, 1]; reshaped to match ``recon_x``.
        mu: Mean of q(z|x).
        logvar: Log-variance of q(z|x).

    Returns:
        Scalar tensor: summed binary cross-entropy plus the summed KL
        divergence between N(mu, exp(logvar)) and N(0, I).
    """
    # Flatten targets to the decoder's output width instead of hard-coding 784.
    # reshape (not view) also accepts non-contiguous inputs.
    target = x.reshape(-1, recon_x.size(-1))
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latents and batch.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
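The first term is the reconstruction loss. If the decoder output is modeled as a Bernoulli distribution, this term is the binary cross-entropy (`binary_cross_entropy`, as used here). If it is instead modeled as a Gaussian with unit standard deviation (std = 1), it reduces, up to a constant scale and offset, to the squared error (`MSELoss`). The second term is the KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior.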
# Load data: MNIST training images, converted with ToTensor and flattened to
# 1-D vectors so they can be fed straight into the fully connected VAE.
batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.view(-1)),  # flatten each image tensor
])
train_loader = DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)
# Initialize the model.
input_dim = 784
hidden_dim = 400
latent_dim = 20
# `device` was previously referenced by the training and sampling code but
# never defined; pick the GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before creating the optimizer so the optimizer holds the
# on-device parameters.
vae = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
# Train the model.
num_epochs = 10
# Send batches to whatever device the model's parameters live on. This keeps
# the loop from depending on a separately defined global `device`.
device = next(vae.parameters()).device
vae.train()
for epoch in range(num_epochs):
    total_loss = 0
    for data, _ in train_loader:  # labels are unused for unsupervised training
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    # loss_function sums over the batch, so dividing by the dataset size
    # yields the average loss per training example.
    print('Epoch %d, Loss: %.4f' % (epoch+1, total_loss / len(train_loader.dataset)))
# Generate images with the trained model by decoding latent vectors drawn
# from the standard normal prior N(0, I).
vae.eval()
# Use the model's own device rather than relying on a global `device`.
device = next(vae.parameters()).device
with torch.no_grad():
    sample = torch.randn(64, latent_dim, device=device)
    # Shape (64, input_dim), values in (0, 1), moved back to the CPU.
    # NOTE(review): for MNIST, reshape to (64, 1, 28, 28) before viewing or
    # saving, e.g. with torchvision.utils.save_image.
    sample = vae.decode(sample).cpu()
评论
发表评论