I have previously used generative adversarial networks and many of their variants to generate signals such as ECG, EMG, EEG, microseismic, mechanical vibration, and radar signals, but the generated signals looked poor in the spectrum or time-frequency domain. So for now I will not touch those complex signals and will simply use handwritten digit images as the example. There are already plenty of Python resources for this, so I won't pile on; the programming environment here is MATLAB R2021b.
First, let's look at the adversarial autoencoder, AAE (Adversarial AutoEncoder). For a general introduction to AAE, see the following article:
"AAE (Adversarial Autoencoders) 浅解" by 嘎嘎小鱼仔 on Zhihu: https://zhuanlan.zhihu.com/p/382958740
AAE builds on the variational autoencoder (VAE); its key extension is the introduction of adversarial training.
The upper half is a simple, typical autoencoder (AE) structure: an input layer, encoder layer, hidden layer, decoder layer, and output layer. The encoder maps the real data x to the latent code z, and the decoder reconstructs x from z. What makes AAE different is that an adversarial game is introduced on the hidden layer to shape the latent code z: the discriminator must distinguish, in the latent space, samples drawn from the prior from the fake codes produced by the encoder (acting as the generator). This adversarial training pushes the encoder's distribution q(z | x) ever closer to the prior p(z).
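In terms of the losses that appear later in modelGradients (this summary is my own paraphrase of the standard GAN-style objective, with the 0.999/0.001 weights taken directly from the code below):

d_loss = -1/2 * mean( log D(z_real) + log(1 - D(z_fake)) )
g_loss = 0.999 * mean( 1/2 * (x_hat - x).^2 ) - 0.001 * mean( log D(z_fake) )

where z_real is sampled from the prior p(z) = N(0, I), z_fake is a sample from the encoder's q(z | x), and x_hat = Decoder(z_fake) is the reconstruction.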
Link to the Adversarial Autoencoders paper: https://arxiv.org/abs/1511.05644
Now let's go straight to the code.
First, load the MNIST handwritten digit images:
load('mnistAll.mat')
Then preprocess the training and test images:
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;   % training labels
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;     % test labels
preprocess is the normalization function, shown below:
function x = preprocess(x)
x = double(x)/255;        % scale pixel values to [0,1]
x = (x-.5)/.5;            % shift and scale to [-1,1]
x = reshape(x,28*28,[]);  % flatten each 28x28 image into a 784x1 column
end
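For displaying reconstructions or generated images later, a matching inverse transform is handy. The following deprocess helper is my own addition (it is not part of the original code) and simply undoes the normalization above:

function x = deprocess(x)
x = x*.5 + .5;            % map from [-1,1] back to [0,1]
x = reshape(x,28,28,[]);  % restore the 28x28 image shape
end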
Next, set the parameters, including the latent space dimension, batch size, learning rates, maximum number of epochs, and so on:
settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;
Now initialize the encoder; the code is easy to follow:
paramsEn.FCW1 = dlarray(initializeGaussian([512,...
prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));
Decoder initialization:
paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));
Discriminator initialization:
paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));
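Before moving on, it can be worth checking that the layer dimensions line up. The following loop is my own addition and just prints the size of every weight and bias in the three networks:

% print the size of every parameter (sanity check)
nets  = {paramsEn, paramsDe, paramsDis};
names = {'Encoder','Decoder','Discriminator'};
for k = 1:numel(nets)
    fn = fieldnames(nets{k});
    fprintf('--- %s ---\n', names{k});
    for j = 1:numel(fn)
        fprintf('%s: %s\n', fn{j}, mat2str(size(nets{k}.(fn{j}))));
    end
end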
% arrays for the average gradients and average squared gradients (Adam state)
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];
Start training:
dlx = gpdl(trainX(:,1),'CB');   % wrap one sample as a gpuArray dlarray
dly = Encoder(dlx,paramsEn);    % single forward pass through the encoder (result not used later)
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
tic;
shuffleid = randperm(size(trainX,2));
trainXshuffle = trainX(:,shuffleid);
fprintf('Epoch %d\n',epoch)
for i=1:numIterations
global_iter = global_iter+1;
idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');
[GradEn,GradDe,GradDis] = ...
dlfeval(@modelGradients,XBatch,...
paramsEn,paramsDe,paramsDis,settings);
% update the discriminator network parameters
[paramsDis,avgG.Dis,avgGS.Dis] = ...
adamupdate(paramsDis, GradDis, ...
avgG.Dis, avgGS.Dis, global_iter, ...
settings.lrD, settings.beta1, settings.beta2);
% update the encoder network parameters
[paramsEn,avgG.En,avgGS.En] = ...
adamupdate(paramsEn, GradEn, ...
avgG.En, avgGS.En, global_iter, ...
settings.lrG, settings.beta1, settings.beta2);
% update the decoder network parameters
[paramsDe,avgG.De,avgGS.De] = ...
adamupdate(paramsDe, GradDe, ...
avgG.De, avgGS.De, global_iter, ...
settings.lrG, settings.beta1, settings.beta2);
if i==1 || rem(i,20)==0
progressplot(paramsDe,settings);
if i==1
h = gcf;
% capture the figure as an image
frame = getframe(h);
im = frame2im(frame);
[imind,cm] = rgb2ind(im,256);
% write to the GIF file
if epoch == 0
imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf);
else
imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append');
end
end
end
end
elapsedTime = toc;
disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
epoch = epoch+1;
if epoch == settings.maxepochs
out = true;
end
end
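Once the loop finishes, the trained parameters only exist in the workspace. Saving them is not part of the original post, but one extra line lets you reload the decoder later for generation (you may want to gather the parameters back to the CPU first if they live on the GPU):

% save the trained parameters for later reuse (my own addition)
save('AAEmnist_params.mat','paramsEn','paramsDe','paramsDis');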
Below are the complete helper functions.
Model gradient computation function:
function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...
    dly(settings.latent_dim+1:2*settings.latent_dim,:).*...
    randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');
% train the discriminator
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));
% train the encoder and decoder (reconstruction plus adversarial term)
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));
% for each network, compute the gradients of its loss with respect to its parameters
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
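The original code does not log the loss values. If you want to monitor them, one option (a sketch of my own, which assumes you also add g_loss and d_loss as extra outputs of modelGradients) is to collect them through dlfeval in the training loop:

% hypothetical modification: also return the losses for logging
[GradEn,GradDe,GradDis,g_loss,d_loss] = dlfeval(@modelGradients,XBatch,...
    paramsEn,paramsDe,paramsDis,settings);
fprintf('iter %d: g_loss = %.4f, d_loss = %.4f\n',global_iter,...
    gatext(g_loss),gatext(d_loss));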
Data extraction function (gathers a dlarray back to a regular CPU array):
function x = gatext(x)
x = gather(extractdata(x));
end
GPU deep learning array wrapper function:
function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end
Weight initialization function:
function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end
Dropout function:
function dly = dropout(dlx,p)
if nargin < 2
p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end
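Note that this dropout helper is defined but never actually called in the networks below. If you wanted to regularize the discriminator, for example, you could insert it after one of its activations (my own suggestion, not something the original code does):

% hypothetical usage: drop 30% of the activations
dly = dropout(dly,0.3);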
Encoder function:
function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end
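As a usage sketch (my own illustration, not in the original code), encoding a test image and drawing a latent sample with the same reparameterization used in modelGradients looks like this:

% hypothetical usage: encode one test image and sample a latent code
dlxTest = gpdl(single(testX(:,1)),'CB');
stats = Encoder(dlxTest,paramsEn);           % [mu; sigma], size 2*latent_dim x 1
mu    = stats(1:settings.latent_dim,:);
sigma = stats(settings.latent_dim+1:end,:);
z     = mu + sigma.*randn(settings.latent_dim,1);
xRec  = Decoder(z,paramsDe);                 % reconstruction in [-1,1]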
Decoder function:
function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end
Discriminator function:
function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end
Dynamic progress plot:
function progressplot(paramsDe,settings)
r = 5; c = 5;
noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
gen_imgs = Decoder(noise,paramsDe);
gen_imgs = reshape(gen_imgs,28,28,[]);
fig = gcf;
if ~isempty(fig.Children)
delete(fig.Children)
end
I = imtile(gatext(gen_imgs));
I = rescale(I);
imagesc(I)
title("Generated Images")
colormap gray
drawnow;
end
Finally, take a look at the generated GIF animation.
Upcoming posts will cover:
(1) Auxiliary Classifier Generative Adversarial Network
(2) Conditional Generative Adversarial Network
(3) Deep Convolutional Generative Adversarial Network
(4) The most basic Generative Adversarial Network
(5) Info Generative Adversarial Network
(6) Least Squares Generative Adversarial Network
(7) The well-known Pixels-to-Pixels (Pix2Pix)
(8) Semi-Supervised Generative Adversarial Network
(9) The well-known Wasserstein Generative Adversarial Network