Building the generator module
from torch import nn
import torch
import numpy as np

latent_dim = 100         # dimensionality of the input noise vector z
img_shape = (1, 28, 28)  # output image shape: 1 channel, 28 x 28 (MNIST-sized)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Basic building block: Linear -> (optional) BatchNorm1d -> LeakyReLU
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  # 0.8 is passed positionally as eps
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Map the latent vector to a flattened image: 100 -> 128 -> 256 -> 512 -> 1024 -> 784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()  # squash pixel values into [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)  # reshape the flat output to (batch, 1, 28, 28)
        return img
Construct a generator instance and inspect its layers
generator = Generator()
generator.model
Sequential(
  (0): Linear(in_features=100, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=128, out_features=256, bias=True)
  (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Linear(in_features=256, out_features=512, bias=True)
  (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Linear(in_features=512, out_features=1024, bias=True)
  (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (10): LeakyReLU(negative_slope=0.2, inplace=True)
  (11): Linear(in_features=1024, out_features=784, bias=True)
  (12): Tanh()
)
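A minimal sketch of how the (still untrained) generator can be called, assuming a batch of standard-normal latent vectors as input:

# Sample a batch of latent vectors and decode them into images.
z = torch.randn(16, latent_dim)   # 16 noise vectors, each of dimension 100
with torch.no_grad():
    fake_imgs = generator(z)      # shape: (16, 1, 28, 28), values in [-1, 1] from Tanh
print(fake_imgs.shape)            # torch.Size([16, 1, 28, 28])

Note that the batch size must be greater than 1 here, since the BatchNorm1d layers compute per-batch statistics in training mode.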