Implementing a VAE with PyTorch: a simple code implementation of the variational autoencoder

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Define the VAE model
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # mu, logvar

    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # which keeps the sampling step differentiable w.r.t. mu and logvar
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
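As a quick sanity check (a sketch of my own, not part of the original post), you can push a random batch through the model and confirm the output shapes:

# Hypothetical shape check on a random batch of 8 "images"
model = VAE(input_dim=784, hidden_dim=400, latent_dim=20)
x = torch.rand(8, 784)  # fake flattened 28x28 inputs in [0, 1]
recon, mu, logvar = model(x)
print(recon.shape, mu.shape, logvar.shape)
# torch.Size([8, 784]) torch.Size([8, 20]) torch.Size([8, 20])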


# Define the loss function: reconstruction term + KL divergence
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Closed-form KL between N(mu, sigma^2) and the standard normal prior
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

The first term is the reconstruction loss. With a Bernoulli decoder it is exactly the binary cross-entropy used above; if we instead assume a Gaussian decoder with fixed standard deviation (std = 1), it reduces (up to a constant) to a mean-squared error. In practice this term is therefore binary_cross_entropy or MSELoss.
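The second term is the KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior, in closed form. As a sanity check (my own sketch, using torch.distributions), it matches PyTorch's built-in KL:

# Verify the closed-form KLD against torch.distributions (illustrative only)
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 20)
logvar = torch.randn(4, 20)
closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).sum()
print(torch.allclose(closed_form, reference))  # True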

# Load the data
batch_size = 128
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform),
                          batch_size=batch_size, shuffle=True)
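Since the Lambda transform flattens each 28x28 image, the loader yields vectors of length 784. A quick peek (illustrative only):

# Inspect one batch: images arrive pre-flattened and scaled to [0, 1] by ToTensor
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 784])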


# Initialize the model (define the device here; it is used in the training
# and sampling loops below)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = 784
hidden_dim = 400
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
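For reference (a check I added, not in the original), this 784-400-20 configuration is small:

# Count trainable parameters: 652,824 for input_dim=784, hidden_dim=400, latent_dim=20
n_params = sum(p.numel() for p in vae.parameters())
print('%d trainable parameters' % n_params)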


# Train the model
num_epochs = 10
vae.train()
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    print('Epoch %d, Loss: %.4f' % (epoch + 1, total_loss / len(train_loader.dataset)))
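After training, reconstructing held-out digits is a common way to eyeball quality. A sketch (the test split, batch size, and output file name are my additions):

# Reconstruct 8 test digits and save inputs (top row) above reconstructions (bottom row)
from torchvision.utils import save_image

test_loader = DataLoader(datasets.MNIST('data', train=False, download=True, transform=transform),
                         batch_size=8, shuffle=True)
vae.eval()
with torch.no_grad():
    x, _ = next(iter(test_loader))
    x = x.to(device)
    recon, _, _ = vae(x)
    grid = torch.cat([x.view(8, 1, 28, 28), recon.view(8, 1, 28, 28)]).cpu()
    save_image(grid, 'reconstruction.png', nrow=8)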


# Generate images with the trained model
vae.eval()
with torch.no_grad():
    sample = torch.randn(64, latent_dim).to(device)
    sample = vae.decode(sample).cpu()
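To actually view the 64 samples, reshape them back to 28x28 images and write them to disk as an 8x8 grid (the file name is my choice):

# Save the generated digits as an image grid
from torchvision.utils import save_image
save_image(sample.view(64, 1, 28, 28), 'generated_samples.png', nrow=8)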

