This article is mainly based on the whiteboard-derivation video series by shuhuai008 on Bilibili:
Index post of all notes in this series:
For a probabilistic model, from the Bayesian perspective the predictive distribution is
$$p(\hat{x}|x)=\int_{\theta}p(\hat{x},\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta,x)\,p(\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta)\,p(\theta|x)\,\mathrm{d}\theta=E_{\theta|x}[p(\hat{x}|\theta)]$$
Inference falls into exact inference and approximate inference; approximate inference is further divided into deterministic approximation (variational inference) and stochastic approximation (e.g. MCMC).
Optimization problems include, for example:
Loss function, unconstrained (e.g. least squares):
$$L(w)=\sum_{i=1}^{N}\left\|w^{T}x_{i}-y_{i}\right\|^{2}$$
$$\hat{w}=\arg\min L(w)$$
Solution methods:
1. Analytic solution: set the derivative to $0$, which gives $w^{*}=(X^{T}X)^{-1}X^{T}Y$.
2. Numerical solution: GD, SGD.
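As a quick illustration (not part of the original notes), here is a minimal numpy sketch comparing the analytic solution $w^{*}=(X^{T}X)^{-1}X^{T}Y$ with plain gradient descent on the same squared loss; the synthetic data, learning rate, and iteration count are my own assumptions.

```python
# A minimal sketch comparing the closed-form least-squares solution with gradient descent
# on L(w) = sum_i ||w^T x_i - y_i||^2. Data generation is purely illustrative.
import numpy as np

rng = np.random.default_rng(0)
N, d = 100, 3
X = rng.normal(size=(N, d))                 # design matrix, one sample per row
w_true = np.array([1.0, -2.0, 0.5])
Y = X @ w_true + 0.01 * rng.normal(size=N)  # noisy targets

# 1. Analytic solution: w* = (X^T X)^{-1} X^T Y (via a linear solve, not an explicit inverse)
w_closed = np.linalg.solve(X.T @ X, X.T @ Y)

# 2. Numerical solution: gradient descent; grad L(w) = 2 X^T (X w - Y)
w = np.zeros(d)
lr = 1e-2
for _ in range(2000):
    w -= lr * 2 * X.T @ (X @ w - Y) / N     # averaged gradient for a stable step size

print(w_closed, w)                          # both should be close to w_true
```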
$$f(w)=\mathrm{sign}(w^{T}x+b)$$
Loss function, constrained (e.g. the hard-margin SVM):
$$\min\ \frac{1}{2}w^{T}w\\
\text{s.t.}\ \ y_{i}(w^{T}x_{i}+b)\geq 1,\quad i=1,2,\cdots,N$$
Solved via convex optimization, e.g. through the dual problem.
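As a hedged sketch of how such a constrained problem can be handed to a generic constrained solver (the notes only mention convex optimization and the dual), the toy below feeds the hard-margin primal directly to scipy's SLSQP routine; the two-cluster data set is assumed for illustration only.

```python
# A toy hard-margin SVM: minimize (1/2) w^T w subject to y_i (w^T x_i + b) >= 1.
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(5)
X = np.vstack([rng.normal([-2, -2], 0.5, size=(20, 2)),
               rng.normal([2, 2], 0.5, size=(20, 2))])   # two linearly separable clusters
y = np.array([-1] * 20 + [1] * 20)

def objective(params):                  # params = [w1, w2, b]
    w = params[:2]
    return 0.5 * w @ w                  # min (1/2) w^T w

constraints = [{'type': 'ineq',         # y_i (w^T x_i + b) - 1 >= 0
                'fun': lambda p, xi=xi, yi=yi: yi * (p[:2] @ xi + p[2]) - 1}
               for xi, yi in zip(X, y)]

res = minimize(objective, x0=np.zeros(3), method='SLSQP', constraints=constraints)
w, b = res.x[:2], res.x[2]
print(w, b, np.all(y * (X @ w + b) >= 1 - 1e-6))          # all margin constraints satisfied
```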
MLE: $$\hat{\theta}=\arg\max\log p(x|\theta)$$
EM (when latent variables $z$ are present):
$$\theta^{(t+1)}=\underset{\theta}{\arg\max}\int_{z}\log p(x,z|\theta)\cdot p(z|x,\theta^{(t)})\,\mathrm{d}z$$
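For concreteness, a small sketch of this EM iteration for a 1-D mixture of two Gaussians (my own toy model, not part of the notes): the E-step computes $p(z|x,\theta^{(t)})$ and the M-step maximizes the expected complete-data log-likelihood in closed form.

```python
# Toy EM for a two-component 1-D Gaussian mixture with latent component indicators z.
import numpy as np

rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(3, 1, 700)])

pi, mu, sigma = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
for _ in range(100):
    # E-step: responsibilities p(z | x, theta^(t))
    dens = np.exp(-0.5 * ((x[:, None] - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    gamma = pi * dens
    gamma /= gamma.sum(axis=1, keepdims=True)
    # M-step: argmax of the expected complete-data log-likelihood
    Nk = gamma.sum(axis=0)
    pi = Nk / len(x)
    mu = (gamma * x[:, None]).sum(axis=0) / Nk
    sigma = np.sqrt((gamma * (x[:, None] - mu) ** 2).sum(axis=0) / Nk)

print(pi, mu, sigma)   # should recover roughly (0.3, 0.7), (-2, 3), (1, 1)
```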
Data:
$x$: observed variable $\rightarrow X:\left\{x_{i}\right\}_{i=1}^{N}$
$z$: latent variable + parameter $\rightarrow Z:\left\{z_{i}\right\}_{i=1}^{N}$
$(X,Z)$: complete data
Introduce a distribution $q(z)$:
$$\log p(x)=\log p(x,z)-\log p(z|x)=\log\frac{p(x,z)}{q(z)}-\log\frac{p(z|x)}{q(z)}$$
Take the expectation of both sides with respect to $q(z)$ (i.e. integrate against $q(z)$):
$$\text{LHS}=\int_{z}q(z)\cdot\log p(x|\theta)\,\mathrm{d}z=\log p(x|\theta)\int_{z}q(z)\,\mathrm{d}z=\log p(x|\theta)$$
$$\text{RHS}=\underbrace{\int_{z}q(z)\log\frac{p(x,z|\theta)}{q(z)}\,\mathrm{d}z}_{ELBO\ (\text{Evidence Lower Bound})}\ \underbrace{-\int_{z}q(z)\log\frac{p(z|x,\theta)}{q(z)}\,\mathrm{d}z}_{KL(q(z)\,\|\,p(z|x,\theta))}=\underbrace{L(q)}_{\text{variational}}+\underbrace{KL(q\,\|\,p)}_{\geq 0}$$
When $q$ equals $p$, $KL(q\|p)=0$, which is its minimum; so the goal becomes making $L(q)$ as large as possible:
$$\tilde{q}(z)=\underset{q(z)}{\arg\max}\ L(q)\ \Rightarrow\ \tilde{q}(z)\approx p(z|x)$$
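A quick numerical sanity check of the decomposition $\log p(x)=ELBO+KL(q\|p(z|x))$, using a toy discrete latent variable with three states (my own example; the joint probabilities are made up):

```python
# Verify log p(x) = ELBO + KL(q || p(z|x)) >= ELBO for a discrete latent z.
import numpy as np

p_xz = np.array([0.10, 0.25, 0.15])   # joint p(x, z) at the observed x, for z = 0, 1, 2
p_x = p_xz.sum()                      # evidence p(x)
post = p_xz / p_x                     # true posterior p(z|x)

q = np.array([0.2, 0.5, 0.3])         # an arbitrary variational distribution q(z)
elbo = np.sum(q * np.log(p_xz / q))   # ELBO = E_q[log p(x,z) - log q(z)]
kl = np.sum(q * np.log(q / post))     # KL(q || p(z|x))

print(np.log(p_x), elbo + kl, elbo)   # log p(x) == ELBO + KL, and ELBO <= log p(x)
```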
We place the following mean-field assumption on $q(z)$: partition the dimensions of the multi-dimensional variable $z$ into $M$ groups that are mutually independent, so that:
$$q(z)=\prod_{i=1}^{M}q_{i}(z_{i})$$
Now fix $q_{i}(z_{i}),\,i\neq j$, and solve for $q_{j}(z_{j})$. We have:
$$L(q)=\underbrace{\int_{z}q(z)\log p(x,z)\,\mathrm{d}z}_{①}-\underbrace{\int_{z}q(z)\log q(z)\,\mathrm{d}z}_{②}$$
For ①:
$$①=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\,\log p(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\
=\int_{z_{j}}q_{j}(z_{j})\left(\int_{z-z_{j}}\log p(x,z)\prod_{i\neq j}^{M}q_{i}(z_{i})\,\mathrm{d}z_{i}\right)\mathrm{d}z_{j}\\
=\int_{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[\log p(x,z)]\,\mathrm{d}z_{j}\\
=\int_{z_{j}}q_{j}(z_{j})\cdot\log\hat{p}(x,z_{j})\,\mathrm{d}z_{j}$$
For ②:
$$②=\int_{z}q(z)\log q(z)\,\mathrm{d}z\\
=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}\log q_{i}(z_{i})\,\mathrm{d}z\\
=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\,[\log q_{1}(z_{1})+\log q_{2}(z_{2})+\cdots+\log q_{M}(z_{M})]\,\mathrm{d}z$$
where, taking the first term as an example,
$$\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{1}(z_{1})\,\mathrm{d}z\\
=\int_{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot\log q_{1}(z_{1})\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\
=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}\cdot\underbrace{\int_{z_{2}}q_{2}(z_{2})\,\mathrm{d}z_{2}}_{=1}\cdot\underbrace{\int_{z_{3}}q_{3}(z_{3})\,\mathrm{d}z_{3}}_{=1}\cdots\underbrace{\int_{z_{M}}q_{M}(z_{M})\,\mathrm{d}z_{M}}_{=1}\\
=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}$$
That is,
$$\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{k}(z_{k})\,\mathrm{d}z=\int_{z_{k}}q_{k}(z_{k})\log q_{k}(z_{k})\,\mathrm{d}z_{k}$$
Therefore,
$$②=\sum_{i=1}^{M}\int_{z_{i}}q_{i}(z_{i})\log q_{i}(z_{i})\,\mathrm{d}z_{i}=\int_{z_{j}}q_{j}(z_{j})\log q_{j}(z_{j})\,\mathrm{d}z_{j}+C$$
where $C$ collects the $i\neq j$ terms, which do not depend on $q_{j}$.
Then ① − ②:
$$①-②=\int_{z_{j}}q_{j}(z_{j})\cdot\log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\,\mathrm{d}z_{j}+C\\
\int_{z_{j}}q_{j}(z_{j})\cdot\log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\,\mathrm{d}z_{j}=-KL(q_{j}(z_{j})\,\|\,\hat{p}(x,z_{j}))\leq 0$$
so the maximum is attained when $q_{j}(z_{j})=\hat{p}(x,z_{j})$.
In the generalized EM algorithm, we first fix $\theta$ and then look for the $q$ closest to $p$; this is exactly where variational inference can be applied:
$$\log p_{\theta}(x)=\underbrace{ELBO}_{L(q)}+\underbrace{KL(q\,\|\,p)}_{\geq 0}\geq L(q)$$
Objective:
$$\hat{q}=\underset{q}{\arg\min}\ KL(q\,\|\,p)=\underset{q}{\arg\max}\ L(q)$$
$$\log q_{j}(z_{j})=E_{\prod_{i\neq j}^{m}q_{i}(z_{i})}[\log p_{\theta}(x,z)]\\
=\int_{z_{1}}\int_{z_{2}}\cdots\int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_{m}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots\mathrm{d}z_{m}$$
$$\log\hat{q}_{1}(z_{1})=\int_{z_{2}}\cdots\int_{z_{m}}q_{2}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{2}\cdots\mathrm{d}z_{m}\\
\log\hat{q}_{2}(z_{2})=\int_{z_{1}}\int_{z_{3}}\cdots\int_{z_{m}}\hat{q}_{1}q_{3}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots\mathrm{d}z_{m}\\
\vdots\\
\log\hat{q}_{m}(z_{m})=\int_{z_{1}}\cdots\int_{z_{m-1}}\hat{q}_{1}\cdots\hat{q}_{m-1}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\cdots\mathrm{d}z_{m-1}$$
Method: coordinate ascent — update each $\hat{q}_{j}$ in turn while holding the others fixed, and iterate until convergence.
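As an illustration of coordinate ascent on these mean-field updates, here is a minimal CAVI sketch for the classic conjugate toy model $x_i\sim N(\mu,\tau^{-1})$, $\mu|\tau\sim N(\mu_0,(\lambda_0\tau)^{-1})$, $\tau\sim\mathrm{Gamma}(a_0,b_0)$ with $q(\mu,\tau)=q(\mu)\,q(\tau)$. The concrete model, priors, and data are my own assumptions; the notes only state the general update equations.

```python
# Coordinate-ascent mean-field VI (CAVI) for a Gaussian with unknown mean and precision.
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(1.5, 2.0, size=500)          # data with true mean 1.5 and std 2.0
N, x_bar = len(x), x.mean()

mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0      # prior hyperparameters
E_tau = a0 / b0                             # initial guess for E_q[tau]
for _ in range(50):
    # update q(mu) = N(mu_N, 1/lam_N), holding q(tau) fixed
    mu_N = (lam0 * mu0 + N * x_bar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # update q(tau) = Gamma(a_N, b_N), holding q(mu) fixed
    a_N = a0 + (N + 1) / 2
    E_quad = (lam0 * ((mu_N - mu0) ** 2 + 1 / lam_N)
              + np.sum((x - mu_N) ** 2) + N / lam_N)   # E_q(mu)[lam0(mu-mu0)^2 + sum(x_i-mu)^2]
    b_N = b0 + 0.5 * E_quad
    E_tau = a_N / b_N

print(mu_N, 1 / np.sqrt(E_tau))             # approx posterior mean of mu and estimated std
```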
$$ELBO=E_{q(z)}\left[\log\frac{p_{\theta}(x^{(i)},z)}{q(z)}\right]=E_{q(z)}[\log p_{\theta}(x^{(i)},z)]+H[q(z)]\\
KL(q\,\|\,p)=\int q(z)\cdot\log\frac{q(z)}{p_{\theta}(z|x^{(i)})}\,\mathrm{d}z$$
Besides coordinate ascent, the optimization can also be carried out by gradient ascent.
Assume $q(Z)=q_{\phi}(Z)$, i.e. a distribution parameterized by $\phi$. Then
$$\underset{q(Z)}{\arg\max}\ L(q)=\underset{\phi}{\arg\max}\ L(\phi)$$
where
$$L(\phi)=E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]$$
Here $x$ denotes a single sample.
$$\nabla_{\phi}L(\phi)=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\
=\nabla_{\phi}\int q_{\phi}(z)\,[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\
=\underbrace{\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{①}+\underbrace{\int q_{\phi}(z)\,\nabla_{\phi}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{②}$$
where
$$②=\int q_{\phi}(z)\,\nabla_{\phi}[\underbrace{\log p_{\theta}(x,z)}_{\text{independent of }\phi}-\log q_{\phi}(z)]\,\mathrm{d}z\\
=-\int q_{\phi}(z)\,\nabla_{\phi}\log q_{\phi}(z)\,\mathrm{d}z\\
=-\int q_{\phi}(z)\,\frac{1}{q_{\phi}(z)}\,\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z\\
=-\int\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z\\
=-\nabla_{\phi}\int q_{\phi}(z)\,\mathrm{d}z\\
=-\nabla_{\phi}1=0$$
Therefore,
$$\nabla_{\phi}L(\phi)=①=\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\
=\int q_{\phi}(z)\,\nabla_{\phi}\log q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\
=E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]$$
This expectation can be approximated by Monte Carlo sampling, which gives an estimate of the gradient; the parameters are then updated by gradient ascent:
$$z^{(l)}\sim q_{\phi}(z),\quad l=1,\cdots,L\\
E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\approx\frac{1}{L}\sum_{l=1}^{L}(\nabla_{\phi}\log q_{\phi}(z^{(l)}))(\log p_{\theta}(x,z^{(l)})-\log q_{\phi}(z^{(l)}))$$
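A toy sketch of this score-function (log-derivative) estimator, assuming the conjugate model $z\sim N(0,1)$, $x|z\sim N(z,1)$ with observed $x=2$ and $q_\phi(z)=N(\phi,1)$ (all my own illustrative choices); with a fixed unit variance the ELBO-optimal variational mean is $x/2=1$, so the iterate should drift there.

```python
# Score-function gradient estimate of grad_phi L and gradient ascent on the ELBO.
import numpy as np

def log_normal(v, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (v - mean) ** 2 / var

rng = np.random.default_rng(3)
x_obs, phi, lr, L = 2.0, -3.0, 0.05, 200
for _ in range(2000):
    z = rng.normal(phi, 1.0, size=L)                            # z^(l) ~ q_phi(z)
    log_p = log_normal(z, 0.0, 1.0) + log_normal(x_obs, z, 1.0)  # log p(x, z)
    log_q = log_normal(z, phi, 1.0)                              # log q_phi(z)
    score = z - phi                                              # grad_phi log q_phi(z) (unit variance)
    grad = np.mean(score * (log_p - log_q))                      # Monte Carlo gradient estimate
    phi += lr * grad                                             # gradient ascent step

print(phi)   # should approach the optimal variational mean, about 1.0
```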
Because of the log term, a tiny change when $q_{\phi}(z)$ is close to $0$ causes a very large error, so this estimator has high variance; the reparameterization trick (Reparameterization Trick) is therefore used instead.
Let $z=g_{\phi}(\varepsilon,x)$ with $\varepsilon\sim p(\varepsilon)$; for $z\sim q_{\phi}(z|x)$ we then have $\left|q_{\phi}(z|x)\,\mathrm{d}z\right|=\left|p(\varepsilon)\,\mathrm{d}\varepsilon\right|$. Substituting into the expression above:
$$\nabla_{\phi}L(\phi)=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\
=\nabla_{\phi}\int q_{\phi}(z)\,[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\
=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,q_{\phi}(z)\,\mathrm{d}z\\
=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,p(\varepsilon)\,\mathrm{d}\varepsilon\\
=\nabla_{\phi}E_{p(\varepsilon)}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\
=E_{p(\varepsilon)}[\nabla_{\phi}(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\\
=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}z]\\
=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}g_{\phi}(\varepsilon,x)]$$
Then perform Monte Carlo sampling of $\varepsilon$, compute the expectation, and obtain the gradient.
The SGVI iteration is:
$$\phi^{(t+1)}\leftarrow\phi^{(t)}+\lambda^{(t)}\cdot\nabla_{\phi}L(\phi)$$
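Combining the reparameterized gradient with this SGVI update, here is a minimal sketch for the same toy model as above, now with $z=g_\phi(\varepsilon,x)=m+s\varepsilon$, $\varepsilon\sim N(0,1)$ and $q_\phi(z|x)=N(m,s^2)$ (again my own illustrative choices); the true posterior is $N(1,0.5)$, so $m$ and $s$ should approach $1$ and $\sqrt{0.5}\approx 0.71$.

```python
# Reparameterization-trick gradient: grad_phi L = E_p(eps)[grad_z(log p - log q) * grad_phi g].
import numpy as np

rng = np.random.default_rng(4)
x_obs, m, s, lr, L = 2.0, -3.0, 2.0, 0.05, 200
for _ in range(2000):
    eps = rng.normal(size=L)                     # eps ~ p(eps)
    z = m + s * eps                              # z = g_phi(eps, x)
    dlogp_dz = -z + (x_obs - z)                  # grad_z log p(x,z) for z ~ N(0,1), x|z ~ N(z,1)
    dlogq_dz = -(z - m) / s ** 2                 # grad_z log q_phi(z)
    dz_dm, dz_ds = 1.0, eps                      # grad_phi g_phi(eps, x) for phi = (m, s)
    grad_m = np.mean((dlogp_dz - dlogq_dz) * dz_dm)
    grad_s = np.mean((dlogp_dz - dlogq_dz) * dz_ds)
    m, s = m + lr * grad_m, s + lr * grad_s      # SGVI update: phi <- phi + lambda * grad

print(m, s)   # should approach the posterior N(1, 0.5): m ~ 1.0, s ~ 0.71
```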
Link to the next chapter: