手帳と試行

学んだことをアウトプットしていきます。 日々、ノートあるのみ。

事後予測分布

未知の入力値に対する出力を確率的に計算するために、事後予測分布を計算する。

事後予測分布

確率モデルにおいて、出力データ D\mathcal D が与えられたという条件のもとで、さらに未知のデータ D\mathcal D_\ast が従う分布 p(DD)p(\mathcal D_\ast | \mathcal D)事後予測分布 (posterior predictive distribution) という。

この分布は、次式のように、既知のデータ D\mathcal D から計算されるパラメータ θ\theta の事後分布 p(θD)p(\theta | \mathcal D) にモデル p(Dθ)p(\mathcal D_\ast | \mathcal \theta) を掛け、さらに θ\theta について積分して周辺化したものとして定義される。

p(DD)dθp(Dθ)p(θD)\begin{aligned} p(\mathcal D_\ast | \mathcal D) \coloneqq \int d\theta p(\mathcal D_\ast | \theta) p(\theta | \mathcal D) \end{aligned}

回帰モデルにおいては次のような形をしている。

p(yX,X,y)dwp(yX,w)p(wX,y)\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) \coloneqq \int d \bm w p(\bm y_\ast | \bm X_\ast, \bm w) p(\bm w | \bm X, \bm y) \end{aligned}

ただし XRd×n,yRd\bm X_\ast \in \R^{d_\ast \times n}, \bm y_\ast \in \R^{d_\ast} は未知の入出力データであり、次のようなものとする。

X=[x1x2xd],y=[y1y2yd]\begin{aligned} \bm X_\ast ={}& \left[\begin{darray}{c} \bm x_{1} \\ \bm x_{2} \\ \vdots \\ \bm x_{d^\ast} \\ \end{darray}\right] ,& \bm y_\ast ={}& \left[\begin{darray}{c} y_{1} \\ y_{2} \\ \vdots \\ y_{d^\ast} \\ \end{darray}\right] \end{aligned}

事後予測分布を計算することで、出力が未知であるような入力データに対する出力データを、不確かさも含めて分布として推定することができる。そこで、事後予測分布を用いるという方法によって線形回帰を行なうものを、俗にベイズ線形回帰 (Bayesian linear regression) などと言ったりする。

ただし、事後予測分布の計算が解析的に可能な場合は限られており、多くの場合はMCMCなどにより近似的に数値計算するのが現実的である。

具体的な計算

具体例を見てみよう。

尤度関数の仮定

まず、尤度関数を以下のようなものにする。

p(yX,w)=Nd(yXw,σ2Id)\begin{aligned} p(\bm y | \bm X, \bm w) = \mathcal N_d(\bm y | \bm X \bm w, \sigma^2 \bm I_d) \end{aligned}

さらに、未知のデータ X,y\bm X_\ast, \bm y_\ast についても、同じ分布に従うものと考え、分布 p(yX,w)p(\bm y_\ast | \bm X_\ast, \bm w) を次式で定める。

p(yX,w)=Nd(yXw,σ2Id)p(\bm y_\ast | \bm X_\ast, \bm w) = \mathcal N_{d^\ast} (\bm y_\ast | \bm X_\ast \bm w, \sigma^2 \bm I_{d^\ast})

ただし、今回は未知のデータの個数を dd^\ast 個として、XRd×n,yRd\bm X_\ast \in \R^{d^\ast \times n}, \bm y_\ast \in \R^{d^\ast} としておこう。

事後分布の計算

続いて、パラメータの事前分布として正規分布を仮定する。

p(w)=Nn(wm0,V0)exp(12(wm0)TV01(wm0))\begin{aligned} p(\bm w) &= \mathcal N_n(\bm w | \bm m_0, \bm V_0) \\ &\propto \exp \left( -\frac{1}{2} (\bm w - \bm m_0)^\mathsf{T} \bm V_0^{-1} (\bm w - \bm m_0) \right) \end{aligned}

この場合、パラメータの事後分布は次のような正規分布になる。

p(wX,y)=Nn(wmd,Vd)p(\bm w | \bm X, \bm y) = \mathcal N_n(\bm w | \bm m_d, \bm V_d)
{md=Vd(1σ2XTy+V01m0)Vd1=1σ2XTX+V01\left\lbrace\begin{aligned} \bm m_d &= \bm V_d \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ \bm V_d^{-1} &= \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \end{aligned}\right.
計算

材料が揃ったので、事後予測分布の計算に取り掛かる。

以下の計算においては、w\bm w および y\bm y_\ast に関係のない項をすべて const.\mathrm{const.} として纏めていることに注意。

p(yX,X,y)=dwp(yX,w)p(wX,y)=dwNd(yXw,σ2Id)Nn(wmd,Vd)dwexp(12σ2yXw22)exp(12(wmd)Vd1(wmd))=dwexp(12(1σ2yXw22+(wmd)Vd1(wmd))(1))(1)=1σ2wTXTXw21σ2wTXTy+1σ2yTy+wTVd1w2wTVd1md+mdTVd1mdconst.=wT(1σ2XTX+Vd1)V+1w2wT(1σ2XTy+Vd1md)V+1m++1σ2yTy+const.=wTV+1w2wTV+1m++1σ2yTy+const.=(wm+)TV+1(wm+)+1σ2yTym+V+1m+(2)+const.(2)=1σ2yTym+V+1m+=1σ2yTyyT(1σ2)2XV+XTy2yT1σ2XV+Vd1md+const.=yT(1σ2Id(1σ2)2XV+XT)Vyy1y2yT1σ2XV+Vd1mdVyy1myy+const.=yTVyy1y2yTVyy1myy+const.=(ymyy)TVyy1(ymyy)+const.dwexp(12(wm+)TV+1(wm+)12(ymyy)TVyy1(ymyy))=dwNn(wm+,V+)Nd(ymyy,Vyy)=Nd(ymyy,Vyy)\begin{aligned} &\hspace{-1pc} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) \\ ={}& \int d \bm w p(\bm y_\ast | \bm X_\ast, \bm w) p(\bm w | \bm X, \bm y) \\ ={}& \int d \bm w \mathcal N_{d^\ast}(\bm y_\ast | \bm X_\ast \bm w, \sigma^2 \bm I_{d^\ast}) \mathcal N_{n}(\bm w | \bm m_d, \bm V_d) \\ \propto{}& \int d \bm w \exp \left( -\frac{1}{2\sigma^2} \| \bm y_\ast - \bm X_\ast \bm w \|_2^2 \right) \exp \left( -\frac{1}{2}(\bm w - \bm m_d) \bm V_d^{-1} (\bm w - \bm m_d) \right) \\ ={}& \int d \bm w \exp \left( -\frac{1}{2} \underbrace{\left( \frac{1}{\sigma^2} \| \bm y_\ast - \bm X_\ast \bm w \|_2^2 + (\bm w - \bm m_d) \bm V_d^{-1} (\bm w - \bm m_d) \right)}_{(1)} \right) \\ &\left|\small\quad\begin{aligned} (1) ={}& \frac{1}{\sigma^2} \bm w^\mathsf{T} \bm X_\ast^\mathsf{T} \bm X_\ast \bm w - 2 \frac{1}{\sigma^2} \bm w^\mathsf{T} \bm X_\ast^\mathsf{T} \bm y_\ast + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast \\ & + \bm w^\mathsf{T} \bm V_d^{-1} \bm w - 2 \bm w^\mathsf{T} \bm V_d^{-1} \bm m_d + \underbrace{\bm m_d^\mathsf{T} \bm V_d^{-1} \bm m_d}_\mathrm{const.} \\ ={}& \bm w^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_d^{-1} \right)}_{\bm V_{+}^{-1}} \bm w - 2 \bm w^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm y_\ast + \bm V_d^{-1} \bm m_d \right)}_{\bm V_{+}^{-1} \bm m_{+}} + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast + \mathrm{const.} \\ ={}& \bm w^\mathsf{T} \bm V_{+}^{-1} \bm w - 2 \bm w^\mathsf{T} \bm V_{+}^{-1} \bm m_{+} + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast + \mathrm{const.} \\ ={}& (\bm w - \bm m_{+})^\mathsf{T} \bm V_{+}^{-1} (\bm w - \bm m_{+}) + \underbrace{ \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm m_{+} \bm V_{+}^{-1} \bm m_{+} }_{(2)} + \mathrm{const.} \\ (2) ={}& \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm m_{+} \bm V_{+}^{-1} \bm m_{+} \\ ={}& \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm y_\ast^\mathsf{T} \left( \frac{1}{\sigma^2} \right)^2 \bm X_\ast \bm V_{+} \bm X_\ast^\mathsf{T} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \frac{1}{\sigma^2} \bm X_\ast \bm V_{+} \bm V_d^{-1} \bm m_d + \mathrm{const.} \\ ={}& \bm y_\ast^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm I_{d^\ast} - \left( \frac{1}{\sigma^2} \right)^2 \bm X_\ast \bm V_{+} \bm X_\ast^\mathsf{T} \right)}_{\bm V_{y_\ast | y}^{-1}} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \underbrace{\frac{1}{\sigma^2} \bm X_\ast \bm V_{+} \bm V_d^{-1} \bm m_d}_{\bm V_{y_\ast | y}^{-1} \bm m_{y_\ast | y}} + \mathrm{const.} \\ ={}& \bm y_\ast^\mathsf{T} \bm V_{y_\ast | y}^{-1} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \bm V_{y_\ast | y}^{-1} \bm m_{y_\ast | y} + \mathrm{const.} \\ ={}& (\bm y_\ast - \bm m_{y_\ast | y})^\mathsf{T} \bm V_{y_\ast | y}^{-1} (\bm y_\ast - \bm m_{y_\ast | y}) + \mathrm{const.} \\ \end{aligned}\right. \\ \propto{}& \int d \bm w \exp \left( - \frac{1}{2} (\bm w - \bm m_{+})^\mathsf{T} \bm V_{+}^{-1} (\bm w - \bm m_{+}) - \frac{1}{2} (\bm y_\ast - \bm m_{y_\ast | y})^\mathsf{T} \bm V_{y_\ast | y}^{-1} (\bm y_\ast - \bm m_{y_\ast | y}) \right) \\ ={}& \int d \bm w \mathcal N_n (\bm w | \bm m_{+}, \bm V_{+}) \mathcal N_{d^\ast} (\bm y_\ast | \bm m_{y_\ast | y}, \bm V_{y_\ast | y}) \\ ={}& \mathcal N_{d^\ast} (\bm y_\ast | \bm m_{y_\ast | y}, \bm V_{y_\ast | y}) \\ \end{aligned}

こうして事後予測分布が計算された。

p(yX,X,y)=Nd(ymyy,Vyy)\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{d^\ast} (\bm y_\ast | \bm m_{y_\ast | y}, \bm V_{y_\ast | y}) \end{aligned}
{myy=1σ2VyyXV+Vd1mdVyy1=1σ2(Id1σ2XV+XT)V+1=1σ2XTX+Vd1md=Vd(1σ2XTy+V01m0)Vd1=1σ2XTX+V01\left\{\begin{aligned} \bm m_{y_\ast | y} ={}& \frac{1}{\sigma^2} \bm V_{y_\ast | y} \bm X_\ast \bm V_{+} \bm V_d^{-1} \bm m_d \\ \bm V_{y_\ast | y}^{-1} ={}& \frac{1}{\sigma^2} \left( \bm I_{d^\ast} - \frac{1}{\sigma^2} \bm X_\ast \bm V_+ \bm X_\ast^\mathsf{T} \right) \\ \bm V_+^{-1} ={}& \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_d^{-1} \\ \bm m_d ={}& \bm V_d \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ \bm V_d^{-1} ={}& \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \end{aligned}\right.