手帳と試行

学んだことをアウトプットしていきます。 日々、ノートあるのみ。

事後予測分布とWoodburyの公式

事後予測分布の式を、もう少し簡単にする。

Woodburyの公式

逆行列の計算に便利な次の公式がある。

(A+BDC)1=A1A1B(D1+CA1B)CA1\begin{aligned} (\bm A + \bm B \bm D \bm C)^{-1} = \bm A^{-1} - \bm A^{-1} \bm B (\bm D^{-1} + \bm C \bm A^{-1} \bm B) \bm C \bm A^{-1} \end{aligned}
ARN×NBRN×MCRM×NDRM×M\begin{aligned} \bm A \in{}& \R^{N \times N} \\ \bm B \in{}& \R^{N \times M} \\ \bm C \in{}& \R^{M \times N} \\ \bm D \in{}& \R^{M \times M} \end{aligned}

特に D=IM\bm D = \bm I_M の場合には次式に帰着する。

(A+BC)1=A1A1B(IM+CA1B)CA1\begin{aligned} (\bm A + \bm B \bm C)^{-1} = \bm A^{-1} - \bm A^{-1} \bm B (\bm I_M + \bm C \bm A^{-1} \bm B) \bm C \bm A^{-1} \end{aligned}

導出は省略する。

特に覚える必要はない。というか、使うときは徹底的に使いまくるので勝手に覚えることになる。

事後予測分布の式の書き換え

事後予測分布は次のようであった。

p(yX,X,y)=Nd(ymyy,Vyy)\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{d^\ast} (\bm y_\ast | \bm m_{y_\ast | y}, \bm V_{y_\ast | y}) \end{aligned}
{myy=1σ2VyyXV+Vd1mdVyy1=1σ2(Id1σ2XV+XT)V+1=1σ2XTX+Vd1md=Vd(1σ2XTy+V01m0)Vd1=1σ2XTX+V01\left\{\begin{aligned} \bm m_{y_\ast | y} ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} \bm X_\ast \bm V_{+} \bm V_d^{-1} \bm m_d \\ \bm V_{y_\ast | y}^{-1} ={}& \frac{1}{\sigma^2} \left( \bm I_{d^\ast} - \frac{1}{\sigma^2} \bm X_\ast \bm V_+ \bm X_\ast^\mathsf{T} \right) \\ \bm V_+^{-1} ={}& \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_d^{-1} \\ \bm m_d ={}& \bm V_d \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ \bm V_d^{-1} ={}& \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \end{aligned}\right.

出てくる記号が多くて計算が大変である。というか、逆行列計算が多すぎて、計算機を使うにしてもちょっと現実的ではない。

そこで、Woodburyの公式を適用してもう少し簡単な形に書き直していく。

1. 準備

まず Vd\bm V_dmd\bm m_d、さらに V+\bm V_+ を計算する。

Vd=(1σ2XTX+V01)1=V0V0XT(XV0XT+σ2Id)1XV0md=Vd(1σ2XTy+V01m0)=(V0V0XT(XV0XT+σ2Id)1XV0)(1σ2XTy+V01m0)=1σ2V0XT(Id(XV0XT+σ2Id)1XV0XT)(1)y+(InV0XT(XV0XT+σ2Id)1X)m0(1)=IdA1(XV0XTC+σ2IdD1)1XV0XTC=(IdA+1σ2XV0XTBDC)1=σ2(XV0XT+σ2Id)1=V0XT(XV0XT+σ2Id)1y+(InV0XT(XV0XT+σ2Id)1X)m0V+=(1σ2XTX+Vd1)1=VdVdXT(XVdXT+σ2Id)1XVd\begin{aligned} \bm V_d ={}& \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \right)^{-1} \\ ={}& \bm V_0 - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \bm V_0 \\ \\ \bm m_d ={}& \bm V_d \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ ={}& \left( \bm V_0 - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \bm V_0 \right) \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ ={}& \frac{1}{\sigma^2} \bm V_0 \bm X^\mathsf{T} \underbrace{ \left( \bm I_d - (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \bm V_0 \bm X^\mathsf{T} \right) }_{(1)} \bm y \\ & + \left( \bm I_n - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \right) \bm m_0 \\ &\left|\small\quad\begin{aligned} (1) ={}& \underbrace{\bm I_d}_{\bm A^{-1}} - ( \underbrace{ \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm C } + \underbrace{ \sigma^2 \bm I_d }_{\bm D^{-1}} )^{-1} \underbrace{ \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm C } \\ ={}& ( \underbrace{ \bm I_d }_{\bm A} + \underbrace{ \frac{1}{\sigma^2} \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm B \bm D \bm C} )^{-1} \\ ={}& \sigma^2 \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d \right)^{-1} \\ \end{aligned}\right. \\ ={}& \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d \right)^{-1} \bm y + \left( \bm I_n - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \right) \bm m_0 \\ \\ \bm V_+ ={}& \left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_d^{-1} \right)^{-1} \\ ={}& \bm V_d - \bm V_d \bm X_\ast^\mathsf{T} (\bm X_\ast \bm V_d \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast})^{-1} \bm X_\ast \bm V_d \\ \end{aligned}
2. 共分散行列の計算

続いて Vyy\bm V_{y_\ast | y} を計算する。

Vyy=σ2(Id1σ2XV+XT)1=σ2(IdX(σ2V+1+XTX)1XT)=X(V+11σ2XTX)1VdXT+σ2Id=XVdXT+σ2Id=XV0XT+σ2IdXV0XT(XV0XT+σ2Id)1XV0XT\begin{aligned} \bm V_{y_\ast | y} ={}& \sigma^2 \left( \bm I_{d^\ast} - \frac{1}{\sigma^2} \bm X_\ast \bm V_+\bm X_\ast^\mathsf{T} \right)^{-1} \\ ={}& \sigma^2 \left( \bm I_{d^\ast} - \bm X_\ast (-\sigma^2 \bm V_+^{-1} + \bm X_\ast^\mathsf{T} \bm X_\ast)^{-1} \bm X_\ast^\mathsf{T} \right) \\ ={}& \bm X_\ast \underbrace{\left( \bm V_+^{-1} - \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast \right)^{-1}}_{\bm V_d} \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} \\ ={}& \bm X_\ast \bm V_d \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} \\ ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \end{aligned}
3. 期待値の計算

最後に myy\bm m_{y_\ast | y}。これはちょっと面倒くさい。

myy=1σ2VyyXV+Vd1mdV+=VdVdXT(σ2Id+XVdXTVyy)1XVd=VdVdXTVyy1XVd=1σ2VyyX(InVdXTVyy1X)md=1σ2Vyy(IdXVdXTVyy1)Xmd=1σ2Vyy(VyyXVdXT)σ2IdVyy1Xmd=Xmd=XV0XT(XV0XT+σ2Id)1y+(XXV0XT(XV0XT+σ2Id)1X)m0\begin{aligned} \bm m_{y_\ast | y} ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} \bm X_\ast \bm V_+ \bm V_d^{-1} \bm m_d \\ &\left|\small\quad\begin{aligned} \bm V_+ ={}& \bm V_d - \bm V_d \bm X_\ast^\mathsf{T} (\underbrace{\sigma^2 \bm I_{d^\ast} + \bm X_\ast \bm V_d \bm X_\ast^\mathsf{T}}_{\bm V_{y_\ast | y}})^{-1} \bm X_\ast \bm V_d \\ ={}& \bm V_d - \bm V_d \bm X_\ast^\mathsf{T} \bm V_{y_\ast | y}^{-1} \bm X_\ast \bm V_d \end{aligned}\right. \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} \bm X_\ast (\bm I_n - \bm V_d \bm X_\ast^\mathsf{T} \bm V_{y_\ast | y}^{-1} \bm X_\ast) \bm m_d \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} (\bm I_{d^\ast} - \bm X_\ast \bm V_d \bm X_\ast^\mathsf{T} \bm V_{y_\ast | y}^{-1}) \bm X_\ast \bm m_d \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} \underbrace{(\bm V_{y_\ast | y} - \bm X_\ast \bm V_d \bm X_\ast^\mathsf{T})}_{\sigma^2 \bm I_{d^\ast}} \bm V_{y_\ast | y}^{-1}\bm X_\ast \bm m_d \\ ={}& \bm X_\ast \bm m_d \\ ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d \right)^{-1} \bm y \\ & + \left( \bm X_\ast - \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \right) \bm m_0 \\ \end{aligned}

計算の途中で myy=Xmd\bm m_{y_\ast | y} = \bm X_\ast \bm m_d というものが出てきた。これは、y=Xw+ε\bm y = \bm X \bm w + \bm \varepsilon において、

ymyy,XX,wmd\bm y \to \bm m_{y_\ast | y}, \quad\bm X \to \bm X_\ast, \quad\bm w \to \bm m_d

と書き換えたものに相当する。

このような置き換えはなかなか素朴なものに感じられるが、実際には「確率的モデル仮定→事後分布の導出→事後予測分布の導出」というステップをきちんと踏んだ上で得られた結果である。妥当性がそれなりに担保されているといえるだろう。

結果

ということで、事後予測分布は次のような形で書ける。

p(yX,X,y)=Nd(ymyy,Vyy)\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{d^\ast} (\bm y_\ast | \bm m_{y_\ast | y}, \bm V_{y_\ast | y}) \end{aligned}
{myy=XV0XT(XV0XT+σ2Id)1y+(XXV0XT(XV0XT+σ2Id)1X)m0Vyy=XV0XT+σ2IdXV0XT(XV0XT+σ2Id)1XV0XT\left\{\begin{aligned} \bm m_{y_\ast | y} ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d \right)^{-1} \bm y \\ & + \left( \bm X_\ast - \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm X \right) \bm m_0 \\ \bm V_{y_\ast | y} ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \\ \end{aligned}\right.

ちなみに

パラメータの事前分布の期待値が m0=0\bm m_0 = \bm 0 である場合、事後予測分布の期待値と共分散行列は

myy=XV0XT(XV0XT+σ2Id)1yVyy=XV0XT+σ2IdXV0XT(XV0XT+σ2Id)1XV0XT\begin{aligned} \bm m_{y_\ast | y} ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} \bm y \\ \bm V_{y_\ast | y} ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \\ \end{aligned}

となる。面白いことに、これはいくつかのかたまりに分けて解釈することができる。

myy=XV0XTKT(XV0XT+σ2Id)1(K+σ2Id)1yVyy=XV0XT+σ2IdK+σ2IdXV0XTKT(XV0XT+σ2Id)1(K+σ2Id)1XV0XTK\begin{aligned} \bm m_{y_\ast | y} ={}& \underbrace{ \bm X_\ast \bm V_0 \bm X^\mathsf{T} }_{\bm K_\ast^\mathsf{T}} \underbrace{ (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d)^{-1} }_{(\bm K + \sigma^2 \bm I_d)^{-1}} \bm y \\ \bm V_{y_\ast | y} ={}& \underbrace{ \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{d^\ast} }_{\bm K_{\ast\ast} + \bm \sigma^2 \bm I_{d^\ast}} - \underbrace{ \bm X_\ast \bm V_0 \bm X^\mathsf{T} }_{\bm K_\ast^\mathsf{T}} \underbrace{ ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_d )^{-1} }_{(\bm K + \sigma^2 \bm I_d)^{-1}} \underbrace{ \bm X \bm V_0 \bm X_\ast^\mathsf{T} }_{\bm K_\ast} \\ \end{aligned}

これらをあらかじめ計算しておけば、事後予測分布を

p(yX,X,y)=Nd(ymy,Vy)p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{d^\ast} (\bm y_\ast | \bm m_y, \bm V_y)
myy=KT(K+σ2Id)1yVyy=K+σ2IdKT(K+σ2Id)1K\begin{aligned} \bm m_{y_\ast | y} ={}& \bm K_\ast^\mathsf{T} (\bm K + \sigma^2 \bm I_d)^{-1} \bm y \\ \bm V_{y_\ast | y} ={}& \bm K_{\ast\ast} + \sigma^2 \bm I_{d^\ast} - \bm K_\ast^\mathsf{T} (\bm K + \sigma^2 \bm I_d)^{-1} \bm K_\ast \\ \end{aligned}
{K=XV0XTK=XV0XTK=XV0XT\left\{\begin{aligned} \bm K ={}& \bm X \bm V_0 \bm X^\mathsf{T} \\ \bm K_{\ast} ={}& \bm X \bm V_0 \bm X_\ast^\mathsf{T} \\ \bm K_{\ast\ast} ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} \\ \end{aligned}\right.

と、比較的シンプルな形で記述することができる。Woodburyの公式を適用する前のものに比べてば、圧倒的に計算が簡単になっていることがわかるだろう。

しかし、それでも計算量が大きいという問題が立ちはだかる。ボトルネックは逆行列 K1\bm K^{-1} の計算部分であり、時間計算量が O(d3)\mathcal O(d^3)、空間計算量が O(d2)\mathcal O(d^2) だけ要求される。そのためデータの個数 dd が大きくなると、現実的な時間では計算が終わらなくなるおそれがある。