Machine Learning Whiteboard Derivation Series Notes (12): Variational Inference (VI)
Published 2019-03-04


These notes follow shuhuai008's whiteboard derivation video series on Bilibili:

Index of all notes in the series:

1. Background

For a probabilistic model:

  • from the frequentist point of view, learning becomes an optimization problem
  • from the Bayesian point of view, learning becomes an integration problem

From the Bayesian perspective:

$$p(\hat{x}|x)=\int_{\theta}p(\hat{x},\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta,x)\,p(\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta)\,p(\theta|x)\,\mathrm{d}\theta=E_{\theta|x}[p(\hat{x}|\theta)]$$

where the third equality uses the assumption that $\hat{x}$ is conditionally independent of $x$ given $\theta$.
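As a sanity check on this expectation form of the predictive distribution, here is a minimal Monte Carlo sketch using an assumed Beta-Bernoulli model (the counts and prior below are illustrative, not from the lecture), chosen because the posterior expectation is available in closed form:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical example: Bernoulli likelihood with a Beta(1, 1) prior.
# After observing 7 heads out of 10 flips, the posterior is Beta(8, 4),
# and p(x_hat = 1 | x) = E_{theta|x}[theta] = 8 / 12 exactly.
a_post, b_post = 1 + 7, 1 + 3

# Monte Carlo estimate of E_{theta|x}[p(x_hat | theta)]
theta = rng.beta(a_post, b_post, size=200_000)   # theta ~ p(theta | x)
p_mc = theta.mean()                              # sampled predictive prob.
p_exact = a_post / (a_post + b_post)             # closed-form answer

print(round(p_mc, 3), round(p_exact, 3))         # both close to 2/3
```

For most models this expectation has no closed form, which is exactly why the approximate-inference methods below are needed.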

Inference divides into:

  • exact inference
  • approximate inference (deterministic approximation: VI; stochastic approximation: MCMC, MH, Gibbs sampling)

Optimization problems include:

  • Regression

model: $f(w)=w^Tx$

loss function (unconstrained):
$$L(w)=\sum_{i=1}^{N}\|w^Tx_i-y_i\|^2,\qquad \hat{w}=\arg\min_w L(w)$$
Solutions:
1. Analytic: set the derivative to $0$, giving $w^*=(X^TX)^{-1}X^TY$
2. Numerical: GD, SGD
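A quick sketch contrasting the two solution routes on synthetic data (the data sizes, noise level, learning rate, and iteration count here are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy data for the unconstrained least-squares problem above.
N, d = 200, 3
X = rng.normal(size=(N, d))
w_true = np.array([2.0, -1.0, 0.5])
Y = X @ w_true + 0.01 * rng.normal(size=N)

# 1. Analytic solution: w* = (X^T X)^{-1} X^T Y
#    (solve the normal equations rather than inverting explicitly)
w_closed = np.linalg.solve(X.T @ X, X.T @ Y)

# 2. Numerical solution: plain gradient descent on L(w)
w_gd = np.zeros(d)
lr = 0.001
for _ in range(5000):
    grad = 2 * X.T @ (X @ w_gd - Y)   # dL/dw
    w_gd -= lr * grad

print(w_closed, w_gd)                  # both recover w_true
```

Both routes reach the same minimizer; gradient descent matters when $X^TX$ is too large to factor or the loss has no closed form.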

  • SVM (classification)

model: $f(w)=\mathrm{sign}(w^Tx+b)$

loss function (constrained):
$$\min\frac{1}{2}w^Tw\quad s.t.\ y_i(w^Tx_i+b)\geq 1,\ i=1,2,\cdots,N$$
Solved by convex optimization and duality.

  • EM

$$\hat{\theta}=\arg\max_{\theta}\log p(x|\theta)$$
$$\theta^{(t+1)}=\underset{\theta}{\arg\max}\int \log p(x,z|\theta)\cdot p(z|x,\theta^{(t)})\,\mathrm{d}z$$

2. Formulation

Data:

$x$: observed variable $\rightarrow X:\{x_i\}_{i=1}^{N}$
$z$: latent variables + parameters $\rightarrow Z:\{z_i\}_{i=1}^{N}$
$(X,Z)$: complete data

Introduce a distribution $q(z)$:

$$\log p(x)=\log p(x,z)-\log p(z|x)=\log\frac{p(x,z)}{q(z)}-\log\frac{p(z|x)}{q(z)}$$

Take the expectation of both sides under $q(z)$ (i.e., integrate both sides against $q(z)$):

Left-hand side $=\int_{z}q(z)\cdot\log p(x|\theta)\,\mathrm{d}z=\log p(x|\theta)\int_{z}q(z)\,\mathrm{d}z=\log p(x|\theta)$

Right-hand side $=\underbrace{\int_{z}q(z)\log\frac{p(x,z|\theta)}{q(z)}\,\mathrm{d}z}_{ELBO\text{ (Evidence Lower Bound)}}\;\underbrace{-\int_{z}q(z)\log\frac{p(z|x,\theta)}{q(z)}\,\mathrm{d}z}_{KL(q(z)\,||\,p(z|x,\theta))}=\underbrace{L(q)}_{\text{variational}}+\underbrace{KL(q||p)}_{\geq 0}$

When $q$ equals $p$, $KL(q||p)$ attains its minimum value $0$, so the goal becomes making $L(q)$ as large as possible:

$$\tilde{q}(z)=\underset{q(z)}{\arg\max}\,L(q)\;\Rightarrow\;\tilde{q}(z)\approx p(z|x)$$
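The decomposition $\log p(x)=L(q)+KL(q||p)$ can be verified numerically on a toy discrete latent variable, where the integrals reduce to finite sums (the probabilities below are arbitrary illustrative values):

```python
import numpy as np

# A tiny discrete model: latent z in {0, 1}, one fixed observation x.
p_xz = np.array([0.1, 0.3])      # joint p(x, z=0), p(x, z=1)
p_x = p_xz.sum()                 # evidence p(x)
p_z_given_x = p_xz / p_x         # posterior p(z | x)

q = np.array([0.6, 0.4])         # an arbitrary variational q(z)

elbo = np.sum(q * np.log(p_xz / q))        # L(q)
kl = np.sum(q * np.log(q / p_z_given_x))   # KL(q || p)

print(np.log(p_x), elbo + kl)    # identical up to floating point
```

Since $KL\geq 0$, the ELBO is always a lower bound on $\log p(x)$, with equality exactly when $q$ matches the posterior.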

We make the following (mean-field) assumption on $q(z)$: partition the dimensions of the multivariate $z$ into $M$ mutually independent groups, so that

$$q(z)=\prod_{i=1}^{M}q_{i}(z_{i})$$

Now fix $q_{i}(z_{i}),\,i\neq j$ and solve for $q_{j}(z_{j})$. Write

$$L(q)=\underbrace{\int_{z}q(z)\log p(x,z)\,\mathrm{d}z}_{①}-\underbrace{\int_{z}q(z)\log q(z)\,\mathrm{d}z}_{②}$$

For ①:

$$\begin{aligned}①&=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log p(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\&=\int_{z_{j}}q_{j}(z_{j})\left(\int_{z-z_{j}}\prod_{i\neq j}^{M}q_{i}(z_{i})\log p(x,z)\prod_{i\neq j}\mathrm{d}z_{i}\right)\mathrm{d}z_{j}\\&=\int_{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[\log p(x,z)]\,\mathrm{d}z_{j}\\&=\int_{z_{j}}q_{j}(z_{j})\cdot\log\hat{p}(x,z_{j})\,\mathrm{d}z_{j}\end{aligned}$$

where we define $\log\hat{p}(x,z_{j})\triangleq E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[\log p(x,z)]$.

For ②:

$$\begin{aligned}②&=\int_{z}q(z)\log q(z)\,\mathrm{d}z\\&=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}\log q_{i}(z_{i})\,\mathrm{d}z\\&=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\left[\log q_{1}(z_{1})+\log q_{2}(z_{2})+\cdots+\log q_{M}(z_{M})\right]\mathrm{d}z\end{aligned}$$

where

$$\begin{aligned}\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{1}(z_{1})\,\mathrm{d}z&=\int_{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot\log q_{1}(z_{1})\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\&=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}\cdot\underbrace{\int_{z_{2}}q_{2}(z_{2})\,\mathrm{d}z_{2}}_{=1}\cdot\underbrace{\int_{z_{3}}q_{3}(z_{3})\,\mathrm{d}z_{3}}_{=1}\cdots\underbrace{\int_{z_{M}}q_{M}(z_{M})\,\mathrm{d}z_{M}}_{=1}\\&=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}\end{aligned}$$

In general,

$$\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{k}(z_{k})\,\mathrm{d}z=\int_{z_{k}}q_{k}(z_{k})\log q_{k}(z_{k})\,\mathrm{d}z_{k}$$

so, collecting the $i\neq j$ terms into a constant $C$ (they are fixed while we optimize $q_{j}$),

$$②=\sum_{i=1}^{M}\int_{z_{i}}q_{i}(z_{i})\log q_{i}(z_{i})\,\mathrm{d}z_{i}=\int_{z_{j}}q_{j}(z_{j})\log q_{j}(z_{j})\,\mathrm{d}z_{j}+C$$

Subtracting:

$$①-②=\int_{z_{j}}q_{j}(z_{j})\log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\,\mathrm{d}z_{j}+C=-KL\left(q_{j}(z_{j})\,||\,\hat{p}(x,z_{j})\right)+C\leq C$$

and the maximum is attained when $q_{j}(z_{j})=\hat{p}(x,z_{j})$.

3. Connection to the EM Algorithm

In generalized EM, we first fix $\theta$ and then look for the $q$ closest to $p$; variational inference applies directly:

$$\log p_{\theta}(x)=\underbrace{ELBO}_{L(q)}+\underbrace{KL(q||p)}_{\geq 0}\geq L(q)$$

Objective:

$$\hat{q}=\underset{q}{\arg\min}\,KL(q||p)=\underset{q}{\arg\max}\,L(q)$$

$$\log q_{j}(z_{j})=E_{\prod_{i\neq j}^{m}q_{i}(z_{i})}[\log p_{\theta}(x,z)]=\int_{z_{1}}\cdots\int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_{m}}q_{1}\cdots q_{j-1}q_{j+1}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\cdots\mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots\mathrm{d}z_{m}$$

$$\begin{aligned}\log\hat{q}_{1}(z_{1})&=\int_{z_{2}}\cdots\int_{z_{m}}q_{2}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{2}\cdots\mathrm{d}z_{m}\\ \log\hat{q}_{2}(z_{2})&=\int_{z_{1}}\int_{z_{3}}\cdots\int_{z_{m}}\hat{q}_{1}q_{3}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots\mathrm{d}z_{m}\\&\;\;\vdots\\ \log\hat{q}_{m}(z_{m})&=\int_{z_{1}}\cdots\int_{z_{m-1}}\hat{q}_{1}\cdots\hat{q}_{m-1}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\cdots\mathrm{d}z_{m-1}\end{aligned}$$

Method: coordinate ascent.

$$ELBO=E_{q(z)}\left[\log\frac{p_{\theta}(x^{(i)},z)}{q(z)}\right]=E_{q(z)}[\log p_{\theta}(x^{(i)},z)]+H[q(z)],\qquad KL(q||p)=\int q(z)\cdot\log\frac{q(z)}{p_{\theta}(z|x^{(i)})}\,\mathrm{d}z$$
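The coordinate-ascent updates can be sketched on the standard toy case of a 2-D Gaussian target with a fully factorized $q$ (the mean and precision matrix below are assumed values; with Gaussian factors only the factor means need iterating):

```python
import numpy as np

# Hypothetical target: p(z) = N(mu, Lambda^{-1}) in 2-D, with a
# mean-field approximation q(z) = q1(z1) q2(z2).
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],
                [0.9, 2.0]])     # precision matrix (positive definite)

# Each optimal factor is Gaussian with fixed variance 1 / Lambda_ii,
# so coordinate ascent only updates the two means m1, m2 in turn.
m = np.zeros(2)                  # initial factor means
for _ in range(50):              # coordinate-ascent sweeps
    m[0] = mu[0] - (Lam[0, 1] / Lam[0, 0]) * (m[1] - mu[1])
    m[1] = mu[1] - (Lam[1, 0] / Lam[1, 1]) * (m[0] - mu[0])

print(m)                         # converges to the true mean (1, -1)
```

Each sweep holds one factor fixed while the other is set to its conditional optimum, which is exactly the $\log\hat{q}_{j}$ scheme above; the means converge to the true posterior mean, while the factorized variances underestimate the true ones.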

4. Stochastic Gradient Variational Inference (SGVI)

(1) Direct Differentiation

Besides coordinate ascent, the optimization can also be carried out by gradient ascent.
Assume $q(Z)=q_{\phi}(Z)$, a distribution parameterized by $\phi$. Then

$$\underset{q(Z)}{\arg\max}\,L(q)=\underset{\phi}{\arg\max}\,L(\phi)$$

where

$$L(\phi)=E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]$$

and $x$ here denotes a single sample.

$$\begin{aligned}\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=\nabla_{\phi}\int q_{\phi}(z)[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=\underbrace{\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{①}+\underbrace{\int q_{\phi}(z)\nabla_{\phi}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{②}\end{aligned}$$

where

$$\begin{aligned}②&=\int q_{\phi}(z)\nabla_{\phi}\Big[\underbrace{\log p_{\theta}(x,z)}_{\text{independent of }\phi}-\log q_{\phi}(z)\Big]\,\mathrm{d}z\\&=-\int q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)\,\mathrm{d}z\\&=-\int q_{\phi}(z)\frac{1}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z\\&=-\int\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z=-\nabla_{\phi}\int q_{\phi}(z)\,\mathrm{d}z=-\nabla_{\phi}1=0\end{aligned}$$

Therefore

$$\begin{aligned}\nabla_{\phi}L(\phi)=①&=\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=\int q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\end{aligned}$$

This expectation can be approximated by Monte Carlo sampling, and the resulting gradient drives gradient ascent on the parameters:

$$z^{(l)}\sim q_{\phi}(z),\qquad E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\approx\frac{1}{L}\sum_{l=1}^{L}(\nabla_{\phi}\log q_{\phi}(z^{(l)}))(\log p_{\theta}(x,z^{(l)})-\log q_{\phi}(z^{(l)}))$$
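A sketch of this score-function estimator in an assumed toy setting where the exact gradient is known: take $q_{\phi}(z)=N(\phi,1)$ and stand in a standard-normal log-density for $\log p_{\theta}(x,z)$, so that $L(\phi)=-KL(q||p)=-\phi^2/2$ and the true gradient is $-\phi$ (all choices here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(2)

# q_phi(z) = N(phi, 1); target log-density is that of N(0, 1).
phi = 2.0
L = 200_000
z = rng.normal(phi, 1.0, size=L)               # z^(l) ~ q_phi(z)

score = z - phi                                # grad_phi log q_phi(z)
log_ratio = -0.5 * z**2 + 0.5 * (z - phi)**2   # log p(z) - log q_phi(z)
grad_mc = np.mean(score * log_ratio)           # score-function estimate

print(grad_mc)                                 # close to -phi = -2
```

Even in this one-dimensional example the estimate needs many samples to stabilize, which illustrates the high variance discussed next.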

Because of the log term, the estimator blows up as $q_{\phi}(z)$ approaches $0$: a tiny change there produces a huge change in the estimate, so its variance is high. The reparameterization trick (not "Tick") addresses this.

(2) The Reparameterization Trick

Let $z=g_{\phi}(\varepsilon,x)$ with $\varepsilon\sim p(\varepsilon)$. For $z\sim q_{\phi}(z|x)$, the change of variables gives $\left|q_{\phi}(z|x)\,\mathrm{d}z\right|=\left|p(\varepsilon)\,\mathrm{d}\varepsilon\right|$. Substituting into the expression above:

$$\begin{aligned}\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,q_{\phi}(z)\,\mathrm{d}z\\&=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,p(\varepsilon)\,\mathrm{d}\varepsilon\\&=\nabla_{\phi}E_{p(\varepsilon)}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=E_{p(\varepsilon)}[\nabla_{\phi}(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\\&=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}z]\\&=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}g_{\phi}(\varepsilon,x)]\end{aligned}$$

Monte Carlo sampling of this expectation again yields the gradient.
The SGVI iteration is:

$$\phi^{(t+1)}\leftarrow\phi^{(t)}+\lambda^{(t)}\cdot\nabla_{\phi}L(\phi)$$
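Putting the pieces together, a minimal SGVI loop with reparameterized gradients (the target, variational family, batch size, and step size are all assumptions for illustration): fit $q_{\phi}=N(\phi,1)$ to a target $p(z)=N(3,1)$, where the optimum is $\phi=3$:

```python
import numpy as np

rng = np.random.default_rng(3)

# Variational family q_phi = N(phi, 1); target p(z) = N(3, 1).
# Reparameterization: z = g_phi(eps) = phi + eps, eps ~ N(0, 1).
phi, lr = 0.0, 0.1
for t in range(500):
    eps = rng.normal(size=64)            # eps ~ p(eps)
    z = phi + eps                        # z = g_phi(eps)
    # grad_z[log p(z) - log q_phi(z)] * grad_phi g_phi(eps)
    # = (-(z - 3) - (-(z - phi))) * 1 = (-(z - 3) + eps) * 1
    grad = np.mean(-(z - 3.0) + eps)
    phi = phi + lr * grad                # phi^{t+1} <- phi^t + lambda * grad

print(phi)                               # converges to 3
```

In this Gaussian-on-Gaussian toy the noise cancels inside the bracket, so the reparameterized gradient is exactly $3-\phi$; in general the estimator is stochastic, but it typically has far lower variance than the score-function form.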

Link to the next chapter:

