莫度编程网

技术文章干货、编程学习教程与开发工具分享

基于分数的生成模型之斯坦因(Stein)分数函数

1、(斯坦因)分数函数 (Stein) score function

1)朗之万动力学公式的第二部分就是梯度x log p(x)。它有个正式的名称,即斯坦因分数函数,也就是

2)应当注意不要将斯坦因分数函数和一般的分数函数混淆了,一般的分数函数定义为:

3)一般的分数函数是对数似然的梯度(关于θ)。相反,斯坦因分数函数式关于数据点x的梯度。最大似然估计使用的就是分数函数,但是朗之万动力学使用斯坦因分数函数。但是,由于绝大多数扩散模型文献中就把斯坦因分数函数简称为分数函数,朗之万动力学中的“分数函数”确切来说应该是斯坦因分数函数。

4)理解分数函数的方法就是记住这个关于数据x的梯度。对于任意的高维分布p(x),由梯度可以得到一个向量场:

5)如果p(x)服从高斯分布,且有p(x)=

如果p(x)服从高斯混合分布,且有p(x)=

对于高斯分布和混合高斯分布的概率密度函数和相应的分数函数如下图所示:

6)分数函数的几何解释

o 在 log p(x) 变化最大的地方,矢量的振幅最强。因此,在log p(x)接近峰值的区域将大多是非常弱的梯度。

o 向量场表明了一个数据点在等值线图中应该如何运动。在下图中可见一个高斯混合(含有两个高斯分布)的等值线图。用箭头来表示向量场,如果考虑一个数据点在空间上,朗之万动力学方程基本上将沿着向量场所指向的方向引导数据点向“盆地”移动。

o 从物理的角度看,分数函数等价于“漂移”,也就是扩散粒子应该如何流动到最低能量状态。

2、分数匹配

朗之万动力学可以直接通过采样生成图像,但是需要知道分布p(x),分数匹配就是求解分数函数,有了分数函数,使用分数函数加随机噪声采样即可生成图像

1)朗之万动力学中最难的问题就是如何得到xp(x),因为我们求不到p(x)。(斯坦因)分数函数的定义:

其中用下标θ表明sθ将通过一个神经网络实现。由于等式右边是未知的,需要一个方便的方法来估计它,简要讨论两种估计。

2)显式分数匹配. 假设给定一个数据集X = {x1,. . . , xM}。人们想到的办法就是通过定义一个分布来考虑经典的核密度估计:

其中h是核函数K(·)的超参,xm是训练集中的第m个样本.下图展示了核密度估计的思想,在左图,展示以不同数据点xm为中心的多个核K(·)。所有这些单核的总和给出了总体的核密度估计q(x)。在右图,展示了一个真实的直方图和相应的核密度估计。我们认为q(x)至多是实际数据分布p(x)的一个近似估计,而实际的数据分布p(x)是不可得的。

由于q(x)是对不可求得的p(x)的一个估计,可以通过q(x)来学习sθ(x)。这就引出下面的损失函数的定义,这个损失函数可以被用来训练一个神经网络,显式分数匹配损失为

是一个典型的L2损失,将核密度估计代入,可以得到:

推导得出了一个可以用于训练的损失函数。只要训练好网络sθ,就可以代入朗之万动力学公式得到递归公式进行图像生成:

显式得分匹配的问题在于核密度估计是对真实分布的一个相当不好的非参数估计。特别是当样本数有限且样本处于高维空间时,核密度估计的计算量非常大,性能会很差

3)去噪声分数匹配. 考虑到显式分数匹配的潜在缺点,引入一个更加流行的分数匹配方法:去噪声分数匹配(DSM)。在DSM中,损失函数被定义为:

主要的不同就是用一个条件分布q(x|x′)代替了分布q(x)。后者需要一个估计,像是核密度估计才行,但是前者不需要。

用高斯分布特殊化一下,令q(x|x′)= N(x | x′,σ2),让x = x′ + σz,可以得到:

因此,去噪声分数匹配的损失函数就变成了:

如果用x替换虚变量x′,在给定训练数据集的前提下,q(x)的采样可以被p(x)的采样代替,去噪声分数匹配的损失函数定义为:

上式它具有高度的可解释性。x + σz是在原图像x中有效地添加噪声σz。分数函数sθ应该取这个噪声图像,并预测噪声 σ/z2 。预测噪声等价于去噪声,因为任何去噪声后的图像加上预测噪声都会得到噪声图像,因此,上式是去噪声步骤。下图说明了分数函数sθ(x)的训练过程。

训练步骤可以形容为:给定一个训练集{x(l)},我们要训练一个网络θ,目标是:

对于推理,假设已经训练好了分数估计器sθ,要生成一个图像,就只要对t =1,. . . , T进行:




控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言

    Powered By Z-BlogPHP 1.7.4

    蜀ICP备2024111239号-43