数学 機械学習/AI

混合密度ネットワーク

2023年10月13日

 

 

 

 

 

 

混合密度ネットワークとは

混合密度ネットワーク(MDN)は、複数の確率分布を組み合わせて出力するニューラルネットワークの一種である。与えられたどんな\(\mathbf{x}\)の値に対しても、混合モデルは任意の条件付き密度関数\(p(\mathbf{t} \mid \mathbf{x})\)をモデル化するための一般的な枠組みを提供する。

 

そもそも教師あり学習の目標とは、条件付き分布\(p(\mathbf{t} \mid \mathbf{x})\)をモデル化することだが、多くの単純な回帰問題ではこれはガウス分布であると仮定されている。のであった

しかし、実用的にはガウス分布と異なる分布を使用しなければならない場面が多く、特にガウス分布では多峰性を持つ問題に対応できない。多峰性は言い換えるとそれらしい答えが複数あるような問題と考えることができる。

混合密度ネットワークと運動学の関わり

例えば、どのような場合で分布が多峰性を持つのかというと、順運動学(forward kinematics(FK))逆運動学(inverse kinematics(IK))という運動学で用いられる概念が簡単な例として挙げられる。

FKとIKというのは運動学はもちろん3D2D問わずアニメーションを制作する人やロボティクスな分野の人にとっては馴染み深い言葉で、どういうものかというと以下のようなものである。

順運動学(forward kinematics(FK))逆運動学(inverse kinematics(IK))
与えられた始点、アームの長さ、関節の回転角から終端位置を求める問題で解は一意に定まる。FKの逆で終端位置から回転角を推定する問題で、解は上の図のような場合二つある。実際のパターン認識ではこちらのほうが多い。

順運動学は通常一意に定まるが、逆運動学は多くの場合に多解性を持つ。例えば、ロボットアームがある点に到達するためには、ジョイントの角度の組み合わせが複数存在する場合がある。このような多解性を扱うために、混合密度ネットワークが有用である。

混合密度ネットワーク(MDN)のメリット

  1. 多解性の表現: MDNは多解性を自然に扱うことができる。複数の解が確率的に生成されるため、状況に応じて最適な解を選択することができる。
  2. 高度な非線形性: 通常のニューラルネットワーク同様、MDNも高度な非線形関数を近似する能力がある。これにより、複雑な関係性も学習することが可能である。
  3. 確率的な推論: 出力が確率分布として得られるため、不確実性のある状況でもその不確実性を量化することができる。

人工的なデータで見てみる

\(\left\{x_n\right\}\):区間(0,1)に一様分布する確率変数\(x\)をサンプリングしたもの
\(\left\{t_n\right\}\):関数\(x_n+0.3 \sin \left(2 \pi x_n\right)\)に区間(-0.1,0.1)の一様乱数を加えて生成されたもの

この同一のデータに対してxとtの役割を入れ替えることで逆問題が得られる。

順問題(左)と逆問題(右)のデータ集合(緑の丸点)と、6個の隠れユニットと単一の線形出力ユニットを持つ二層ニューラルネットワークで二乗和誤差関数を最小化して得られるフィッティング(赤線)をあらわした図

最小二乗法はガウス分布の仮定の下では最尤推定に相当し、逆問題のような非ガウス性を持つものについては、右の逆問題の図の赤線を見てもらうとわかるのだが、しょぼい結果しか得られない。

 

ではどうするのかというと、ここで出てくるのが混合モデルである。

ガウス分布は一つの山を予測するものだが、混合ガウス分布は複数の山を予測するイメージである


画像引用:https://xtrend.nikkei.com/atcl/contents/18/00076/00009/

混合密度ネットワークの数学的モデル

MDNのモデル

$$
p(\mathbf{t} \mid \mathbf{x})=\sum_{k=1}^K \pi_k(\mathbf{x}) \mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k(\mathbf{x}), \sigma_k^2(\mathbf{x})\right)
$$

  • \(p(\mathbf{t} \mid \mathbf{x})\): これは、入力\(\mathbf{x}\)が与えられたときの出力\(\mathbf{t}\)の確率分布を表している。
  • \(\pi_k(\mathbf{x})\): これは\(k\)番目のガウス分布の「重要度」を表す。すべての\(\pi_k(\mathbf{x})\)を足し合わせると1になる。
  • \(\mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k(\mathbf{x}), \sigma_k^2(\mathbf{x})\right)\): これは\(k\)番目のガウス分布。平均値は\(\boldsymbol{\mu}_k(\mathbf{x})\)、分散は\(\sigma_k^2(\mathbf{x})\)である。

なぜこれが便利なのか?

  1. 多様な出力: このモデルは複数のガウス分布を組み合わせるため、一つの入力\(\mathbf{x}\)に対して複数の出力\(\mathbf{t}\)を持つことができる。これが多解性の問題に役立つ。
  2. 異分散の対応: 分散\(\sigma_k^2(\mathbf{x})\)が入力\(\mathbf{x}\)に依存する形になっているため、入力によって出力のばらつきが変わるような状況にも対応できる。
  3. 一般的な共分散行列の拡張: コレスキー分解などを使うことで、さまざまな形のばらつきにも対応できる。
  4. 条件付き分布の非分解性: 普通のモデルでは、出力\(\mathbf{t}\)の各要素が独立であると仮定することが多い。しかし、このモデルではそのような仮定をしていない。つまり、出力\(\mathbf{t}\)の各要素がどう関連しているかも学習できる。
コレスキー分解の部分について

コレスキー分解とは

コレスキー分解(Cholesky decomposition)は、正定値対称行列を下三角行列とその転置行列の積に分解する手法である。この手法は、線形代数や統計学、機械学習など、多くの分野で活用される。

ある正定値対称行列(すなわち、全ての固有値が正で、かつ転置行列が元の行列と等しい行列)\(A=\left[a_{i j}\right] \in \mathbb{R}^{n \times n}\)に対して、コレスキー分解は以下のように表現される。

$$
A=L L^T
$$

ここで、\(L=\left[l_{i j}\right] \in \mathbb{R}^{n \times n}\)は下三角行列(すなわち、対角線より上の成分が全て0の行列)であり、\(L^T\)はその転置行列である。

コレスキー分解の計算について

$$
A=\left[\begin{array}{cccc}
a_{11} & a_{21} & \cdots & a_{n 1} \\
a_{21} & a_{22} & \cdots & a_{n 2} \\
\vdots & \vdots & \ddots & \vdots \\
a_{n 1} & a_{n 2} & \cdots & a_{n n}
\end{array}\right], \quad L=\left[\begin{array}{cccc}
l_{11} & 0 & \cdots & 0 \\
l_{21} & l_{22} & \cdots & \vdots \\
\vdots & \vdots & \ddots & 0 \\
l_{n 1} & l_{n 2} & \cdots & l_{n n}
\end{array}\right] $$

として

$$
L L^{\mathrm{T}}=\left[\begin{array}{cccc}
l_{11} & 0 & \cdots & 0 \\
l_{21} & l_{22} & \cdots & \vdots \\
\vdots & \vdots & \ddots & 0 \\
l_{n 1} & l_{n 2} & \cdots & l_{n n}
\end{array}\right]\left[\begin{array}{cccc}
l_{11} & l_{21} & \cdots & l_{n 1} \\
0 & l_{22} & \cdots & l_{n 2} \\
\vdots & \vdots & \ddots & \vdots \\
0 & \cdots & 0 & l_{n n}
\end{array}\right]=\left[\begin{array}{cccc} l_{11}^2 & l_{11} l_{21} & \cdots & l_{11} l_{n 1} \\ l_{11} l_{21} & l_{21}^2+l_{22}^2 & \cdots & l_{21} l_{n 1}+l_{22} l_{n 2} \\ \vdots & \vdots & \ddots & \vdots \\ l_{11} l_{n 1} & l_{21} l_{n 1}+l_{22} l_{n 2} & \cdots & l_{n 1}^2+l_{n 2}^2+\cdots+l_{n n}^2 \end{array}\right] $$

で両辺の係数を比較して

$$
\begin{aligned}
& a_{11}=l_{11}^2 \\
& a_{21}=l_{11} l_{21} \quad a_{22}=l_{21}^2+l_{22}^2 \\
& \vdots \quad \vdots \quad \ddots \\
& a_{n 1}=l_{11} l_{n 1} \quad a_{n 2}=l_{21} l_{n 1}+l_{22} l_{n 2} \quad \cdots \quad a_{n n}=l_{n 1}^2+l_{n 2}^2+\cdots+l_{n n}^2 \\
&
\end{aligned}
$$

対角成分を求める際には、平方根の計算が含まれるのだが、その対角成分をすべて正として計算すれば一意に逆行列を計算することができる。

\(L\)の各成分は一列目から順に左から計算できて

$$
\begin{aligned}
& l_{11}=\sqrt{a_{11}} \\
& l_{21}=\frac{a_{21}}{l_{11}} \quad l_{22}=\sqrt{a_{22}-l_{21}^2} \\
& \vdots \quad \vdots \quad \ddots \\
& l_{n 1}=\frac{a_{n 1}}{l_{11}} \quad l_{n 2}=\frac{a_{n 2}-l_{21} l_{n 1}}{l_{22}} \quad \cdots \quad l_{n n}=\sqrt{a_{n n}-\left(l_{n 1}^2+l_{n 2}^2+\cdots+l_{n(n-1)}^2\right)} \\
&
\end{aligned}
$$

となる。

これは後でも解説するが、混合ガウス分布のパラメータをニューラルネットワークで推定する場合、出力パラメータは各々に関する制約条件を常に満たしている必要がある。さもなければ学習中にパラメータは不正な値となり、学習が破綻する。そのため、出力層の構造を工夫する必要がある。

ガウス分布の平均\(\mu_i\)は実数すべてを取りうるから、対応する出力層の重みづけ和\(a_{\boldsymbol{\mu}_i}\)について活性化は行わない。

ガウス分布の混合の重み\(\pi_i\)は正でありかつ総和が1である必要がある。そのため、対応する出力層の重みづけ和\(a_{\pi_i}\)はsoftmax関数による活性化

$$
\pi_i=\frac{\exp \left(a_{\pi_i}\right)}{\sum_i \exp \left(a_{\pi_i}\right)}
$$

を行う。多変量正規分布の確率密度関数は次のように定義されるのだったが

$$
\mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma})=\frac{1}{(2 \pi)^{d / 2} \sqrt{|\boldsymbol{\Sigma}|}} \exp \left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^T \boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)
$$

共分散行列\(\Sigma_i^{-1}\)については正定値対称性が満たされる必要があるため、ネットワークの出力をそのままこの行列とはせずに、以下のようにすればいい。

正定値対称行列である\(\Sigma_i^{-1}\)は下三角行列

$$
L=\left(\begin{array}{rrr}
l_{00} & 0 & 0 \\
l_{10} & l_{11} & 0 \\
l_{20} & l_{21} & l_{22}
\end{array}\right)
$$

を用いて

$$
\Sigma_i^{-1}=L L^{\mathrm{T}}
$$

とコレスキー分解される。一方、このような\(L\)から求められる\(\Sigma_i^{-1}\)は必ず正定値対称となる。この性質を利用し、ネットワークは\(\Sigma_i^{-1}\)の代わりに下三角行列\(L\)を推定する。ただし、対角行列\(l_{i i}\)については正の値となるよう指数関数による活性化を行い、それ以外の要素については活性化を行わない。これより

$$
det\left(\Sigma_i^{-1}\right)=det\left(L L^{\mathrm{T}}\right)=det(L) det\left(L^{\mathrm{T}}\right)
$$

から

$$
\sqrt{det\left(\Sigma_i^{-1}\right)}=l_{00} l_{11} l_{22}
$$

と求められる。

 

まず基本的な概念から入る。ここで考えるのは「ニューラルネットワーク」が「確率モデル」をどのように作り出すかだ。通常、ニューラルネットワークは入力に対して一つの出力を予測する。例えば、手書き数字(入力)に対してその数字が何か(出力)を予測するような場合だ。しかし、この混合密度ネットワークは一歩進んで、入力に対して「確率的な出力」を与える。

 

混合モデルのニューラルネットワークの役割

通常のニューラルネットワークが、入力に対して一つの出力を与えるのに対して、混合密度ネットワークは入力に対してこれらの「重み」や「平均値」、「分散」を出力する。

例えば、ニューラルネットワークが一つの隠れ層としてシグモイド関数を持つとしよう。この場合、ネットワークの出力ユニットは以下のようになる。

  • \(\pi_k(\mathbf{x})\)(混合係数)を決定するK個の出力ユニット: \(a_k^\pi\)
  • \(\sigma_k(\mathbf{x})\)(カーネルの幅、つまり分散)を決定するK個の出力ユニット: \(a_k^\sigma\)
  • \(\mu_{k j}(\mathbf{x})\)(カーネルの中心、つまり平均)を決定するK × L個の出力ユニット: \(a_{k j}^\mu\)

通常のネットワークが目標変数に対してL個しか出力しないのに対し、合密度ネットワークは\((L + 2)K\)個の出力を持つ。これにより、入力に対して確率的な出力が得られるわけだ。

総じて、混合密度ネットワークは確率的な予測を行うための高度な手法であり、それぞれの出力ユニットが特定の役割を果たしている。

パラメータの生成

混合係数

混合係数は

$$
\sum_{k=1}^K \pi_k(\mathbf{x})=1, \quad 0 \leq \pi_k(\mathbf{x}) \leq 1
$$

という制約条件を満たす必要があるので、これはsoftmax関数の出力

$$
\pi_k(\mathbf{x})=\frac{\exp \left(a_k^\pi\right)}{\sum_{l=1}^K \exp \left(a_l^\pi\right)}
$$

を用いれば達成される。

分散

分散の制約条件は\(\sigma_k^2(\mathbf{x}) \geq 0\)なので、指数関数を用いればよく

$$
\sigma_k(\mathbf{x})=\exp \left(a_k^\sigma\right)
$$

平均

平均は実数の要素を持つだけなので、ネットワークの出力をそのまま移す恒等関数でよい

$$
\mu_{k j}(\mathbf{x})=a_{k j}^\mu
$$

 

混合密度ネットワークの誤差関数

誤差関数\(E(\mathbf{w})\)は、モデルの予測がどれだけ実際のデータから外れているかを数値で示す関数である。この関数を最小化するようにニューラルネットワークのパラメータ(重みとバイアス)\(\mathbf{w}\)を調整する。

MDNの誤差関数は以下のように定義される。

MDNの誤差関数

$$
E(\mathbf{w})=-\sum_{n=1}^N \ln \left\{\sum_{k=1}^K \pi_k\left(\mathbf{x}_n, \mathbf{w}\right) \mathcal{N}\left(\mathbf{t}_n \mid \boldsymbol{\mu}_k\left(\mathbf{x}_n, \mathbf{w}\right), \sigma_k^2\left(\mathbf{x}_n, \mathbf{w}\right) \mathbf{I}\right)\right\}
$$

ここで、\(N\)はデータの数、\(K\)は混合モデルの成分数である。\(\pi_k\), \(\boldsymbol{\mu}_k\), \(\sigma_k^2\)は、それぞれ混合係数、平均、分散を表している。これらはニューラルネットワークの出力であり、入力\(\mathbf{x}\)とパラメータ\(\mathbf{w}\)に依存する。

この式は少し複雑だが、基本的には以下のような手順で計算される。

  1. 各データ点\(\mathbf{x}_n\)に対して、モデルが予測する確率分布を計算する。
  2. この確率分布を使って、実際のターゲット\(\mathbf{t}_n\)がどれだけ確率的に起こり得るかを計算する。
  3. それを全データ点で足し合わせ、負の対数をとる。

尤度を最大化する(つまり、\(E(\mathbf{w})\)を最小化する)パラメータ\(\mathbf{w}\)(ニューラルネットワークの重みとバイアス)を見つけるのが目的である。

逆伝播について

逆伝播(Backpropagation)は、ニューラルネットワークの訓練において、誤差関数の微分(勾配)を効率よく計算するためのアルゴリズムである。

MDNのネットワーク構造は通常のニューラルネットワークと同じなので、通常の逆伝播の手続きを用いて、誤差関数\(E(\mathbf{w})\)の微分を計算することができる。具体的には、各出力ユニットで計算される誤差信号\(\delta\)が、隠れユニットへと逆伝播される。これにより、誤差関数に対する各パラメータの微分が計算される。

誤差信号\(\delta\)は、出力層で計算され、入力層に向かって逆方向に伝播していく。この\(\delta\)を用いて、各重みに関する誤差関数の微分が求められる。全てのデータポイントに対してこの操作を行い、合計することで、誤差関数\(E(\mathbf{w})\)に対する全体の微分が求まる。

 

\(\gamma\)の意味とその扱い

これから、混合密度ネットワークの文脈で登場する\(\gamma_k(\mathbf{t} \mid \mathbf{x})\)は、事後分布と呼ばれるものである。この事後分布は、入力\(\mathbf{x}\)と対応する目標\(\mathbf{t}\)が与えられた条件下で、そのデータ点が\(k\)番目のガウス分布から生成された確率を示している。

事前分布とは、何も観測しない状態での信念や確率分布を表す。混合係数\(\pi_k(\mathbf{x})\)は、事前における各成分(ガウス分布)\(k\)がデータ点に対して持つ「重み」や「確率」を表している。言い換えれば、\(\pi_k\)は入力\(\mathbf{x}\)が与えられたときに、そのデータ点が\(k\)番目のガウス分布から生成される確率である。

対して、事後分布は観測データを考慮に入れた後の確率分布である。この事後分布\(\gamma_k(\mathbf{t} \mid \mathbf{x})\)は、データ点が実際に観測された後で、そのデータ点が\(k\)番目のガウス分布から生成された確率を更新して示している。

$$
\gamma_k(\mathbf{t} \mid \mathbf{x})=\frac{\pi_k \mathcal{N}_{n k}}{\sum_{l=1}^K \pi_l \mathcal{N}_{n l}}
$$

この式は、Bayesの定理に基づいている。具体的には、\(\pi_k\)が事前確率、\(\mathcal{N}_{n k}\)尤度、\(\gamma_k\)が事後確率となる。分母の\(\sum_{l=1}^K \pi_l \mathcal{N}_{n l}\)は全確率の法則による正規化項で、すべての成分を通じての確率の合計が1になるようにしている。

この事後分布\(\gamma_k\)を用いると、モデルが出力する各ガウス分布がどれだけデータ点に「フィット」しているのかがわかる。この値が大きいほど、その成分(ガウス分布)がそのデータ点に対してより「適している」と解釈できる。この事後分布を計算することで、より柔軟かつ効率的な学習や予測が可能になる。

 

各パラメータの微分

モデルを訓練するためには、誤差関数\(E_n\)を最小化する必要がある。誤差関数の形は複雑であるが、それを各パラメータで微分することで、どの方向にパラメータを更新すれば誤差が減少するのかがわかる。

混合係数\(\pi_k\)に対する微分

$$
\frac{\partial E_n}{\partial a_k^\pi}=\pi_k-\gamma_k
$$

証明

微分のchain ruleから

$$
\frac{\partial E_n}{\partial a_k^\pi}=\sum_{j=1}^K \frac{\partial E_n}{\partial \pi_j} \frac{\partial \pi_j}{\partial a_k^\pi}
$$

この第1項について、先ほどの\(\gamma\)を使って

$$
\frac{\partial E_n}{\partial \pi_j}=-\frac{\mathcal{N}_{n j}}{\sum_{l=1}^K \pi_l \mathcal{N}_{n l}}=-\frac{\gamma_{n j}}{\pi_j}
$$

そして第二項について

$$
\begin{aligned}
\frac{\partial \pi_j}{\partial a_k^\pi} & =\frac{\partial}{\partial a_k^\pi}\left(\frac{e^{a_j^\pi}}{\sum_{l=1}^K e^{a_l^\pi}}\right) \\
& =\pi_j\left(\delta_{k j}-\pi_k\right)
\end{aligned}
$$

よって、この二式を合わせると

$$
\begin{aligned}
\frac{\partial E_n}{\partial a_k^\pi} & =\sum_{j=1}^K\left(-\frac{\gamma_{n j}}{\pi_j}\right) \pi_j\left(\delta_{k j}-\pi_k\right) \\
& =\sum_{j=1}^K \gamma_{n j}\left(\pi_k-\delta_{k j}\right) \\
& =-\gamma_{n k}+\sum_{j=1}^K \gamma_{n j} \pi_k \\
& =\pi_k-\gamma_{n k}\left(\text{∵}\sum_{j=1}^K \gamma_{n j}=1\right)
\end{aligned}
$$

この式は混合係数\(\pi_k\)を少し動かしたときに、誤差関数がどれくらい変わるかを示す。\(\gamma_k\)は事後分布であり、実際のデータがどれだけ\(k\)番目のガウス分布に当てはまるかを示している。この微分が正であれば、混合係数\(\pi_k\)を減らす方向に更新すると誤差が減少することを示している。

 

各要素の平均\(\mu_{kl}\)に対する微分

$$
\frac{\partial E_n}{\partial a_{k l}^\mu}=\gamma_k\left\{\frac{\mu_{k l}-t_l}{\sigma_k^2}\right\}
$$

証明

$$
a_{k l}^\mu=\mu_{k l}
$$

より

$$
\frac{\partial E_n}{\partial a_{k l}^\mu}=\frac{\partial E_n}{\partial \mu_{k l}}
$$

が得られる

$$
\begin{gathered}
\partial E_n=-\sum_{n=1}^N \ln \left(\sum_{k=1} \pi_k \mathcal{N}_{n k}\right) \\
\gamma_{n k}=\frac{\pi_k \mathcal{N}_{n k}}{\sum_{l=1}^K \pi_l \mathcal{N}_{n l}}
\end{gathered}
$$

これらとガウス分布の式を併せて

$$
\begin{aligned}
\frac{\partial E_n}{\partial \mu_{k l}} & =-\frac{\pi_k}{\sum_{k=1} \pi_k \mathcal{N}_{n k}} \cdot \mathcal{N}_{n k} \cdot \frac{t_{n l}-\mu_{k l}}{\sigma^2} \\
& =\gamma_{n k} \frac{\mu_{k l}-t_{n l}}{\sigma_k^2}
\end{aligned}
$$

この式も同様に、平均\(\mu_{kl}\)を少し動かしたときに誤差関数がどれくらい変わるかを示す。\(\gamma_k\)はその成分がデータにどれくらいフィットしているかを示し、\(\frac{\mu_{kl}-t_l}{\sigma_k^2}\)は平均と目標値がどれだけ離れているかを示している。

各要素の分散\(\sigma_k\)に対する微分

$$
\frac{\partial E_n}{\partial a_k^\sigma}=-\gamma_k\left\{\frac{\left\|\mathbf{t}-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^3}-\frac{1}{\sigma_k}\right\}
$$

証明

微分のchain ruleより

$$
\frac{\partial E_n}{\partial a_k^\sigma}=\frac{\partial E_n}{\partial \sigma_k} \frac{\partial \sigma_k}{\partial a_k^\sigma}
$$

第二項について

$$
\sigma_k=\exp \left(a_k^\sigma\right)
$$

より

$$
\frac{\partial \sigma_k}{\partial a_k^\sigma}=\exp \left(a_k^\sigma\right)=\sigma_k
$$

ガウス分布の定義より

$$
\begin{aligned}
\mathcal{N}_{n k} & =\frac{1}{2 \pi^{D / 2}} \frac{1}{\left|\sigma_{k^2} I\right|} \exp \left(-\frac{1}{2}\left(\mathbf{t}_n-\boldsymbol{\mu}_k\right)^T \frac{1}{\sigma_k^2}\left(\mathbf{t}_n-\boldsymbol{\mu}_k\right)\right) \\
& =\left(\frac{1}{2 \pi \sigma_k^2}\right)^{\frac{D}{2}} \exp \left(-\frac{1}{2}\left(\mathbf{t}_n-\boldsymbol{\mu}_k\right)^T \frac{1}{\sigma_k^2}\left(\mathbf{t}_n-\boldsymbol{\mu}_k\right)\right) \\
& =\left(\frac{1}{2 \pi \sigma_k^2}\right)^{\frac{D}{2}} \exp \left(-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^2}\right)
\end{aligned}
$$

第1項について変形したガウス分布の式を用いて

$$
\begin{aligned}
\frac{\partial E_n}{\partial \sigma_k} & =\frac{\pi_k}{-\sum_{k=1}^K \pi_k \mathcal{N}_{n k}}\left(\frac{1}{2 \pi}\right)^{\frac{D}{2}}\\
& \qquad \times \left(-\frac{L}{\sigma^{L+1}} \exp \left(-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^2}\right) +\frac{1}{\sigma_k^2} \exp \left(-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^2}\right.\right)\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^3}). \\
& =\frac{\mathcal{N}_{n k} \pi_k}{-\sum_{k=1}^K \pi_k \mathcal{N}_{n k}}\left(-\frac{L}{\sigma_k}+\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^3}\right) \\
& =\gamma_{n k}\left(\frac{L}{\sigma_k}-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^3}\right)
\end{aligned}
$$

最後に求めた第1項と第2項に掛け合わせて

$$
\begin{aligned}
\frac{\partial E_n}{\partial a_k^\sigma} & =\frac{\partial E_n}{\partial \sigma_k} \frac{\partial \sigma_k}{\partial a_k^\sigma} \\
& =\gamma_{n k}\left(\frac{L}{\sigma_k}-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^3}\right) \cdot \sigma_k \\
& =\gamma_{n k}\left(L-\frac{\left\|\mathbf{t}_n-\boldsymbol{\mu}_k\right\|^2}{\sigma_k^2}\right)
\end{aligned}
$$

よって示せた□

この式は、分散\(\sigma_k\)が誤差に与える影響を示す。ここでも\(\gamma_k\)が登場し、それがかかる項が分散の誤差にどれだけ寄与するかを決める。

総じて、これらの微分はニューラルネットワークの逆伝播で用いられ、各パラメータの更新に役立つ。このようにして、混合密度ネットワークは複雑なデータ分布を効率的に学習する。

 

混合密度ネットワークを先の人工的なデータで使ってみる

混合密度ネットワークをいったん訓練すれば、与えられた任意の入力ベクトルの値に対する目標データの条件付き密度関数が予測できる。

先ほどの最小二乗法では貧弱になったフィッティングに対して混合密度ネットワークを使用してみる。

 

混合密度ネットワークを、先ほどの人工的に生成されたデータで訓練した。ここでは混合数3の混合ガウス分布を使用することにする

3つのカーネル関数に対して混合係数\(\pi_k(x)\)をxの関数としてプロットしたもの。モデルには3つのガウス関数を要素とし、5つのtanhシグモイド隠れユニット、そして9つの出力ユニットを持つ二層パーセプトロンを用いている(9つの出力ユニットは、ガウス要素の3つの平均と3つの分散、3つの混合係数に対応している)

 

平均と分散

混合密度ネットワークをいったん訓練すれば、与えられた任意の入力ベクトルの値に対する目標データの条件付き密度関数が予測できる。

この密度関数を使えばいろいろな応用例において有用なもっと特定された量を計算できる。

混合密度ネットワークの平均

混合ネットワークの訓練が済み、出力されたパラメータを用いて、有用な量を計算する運びになるのだが。ここで一番単純な例は平均であり

$$
\mathbb{E}[\mathbf{t} \mid \mathbf{x}]=\int \mathbf{t} p(\mathbf{t} \mid \mathbf{x}) \mathrm{d} \mathbf{t}=\sum_{k=1}^K \pi_k(\mathbf{x}) \boldsymbol{\mu}_k(\mathbf{x})
$$

この式は見ての通り、混合係数に応じて、各ガウス分布の平均を足し合わせる。

最小二乗で訓練した標準的なネットワークは条件付き平均を近似しているので、従来の最小二乗の結果は混合密度ネットワークの独別な場合である。

ということは、先ほどから述べているように、多峰性を持つ分布の場合には条件付き平均ではあまり意味がない。考えてみれば当たり前で解が例えば2つある場合、いずれかを選ぶ必要があるが、この2つの解の平均は解にはならない。

つまりダメ;;

 

混合密度ネットワークの分散

では分散はどうなるのかというと

$$
\begin{aligned}
s^2(\mathbf{x}) & =\mathbb{E}\left[\|\mathbf{t}-\mathbb{E}[\mathbf{t} \mid \mathbf{x}]\|^2 \mid \mathbf{x}\right] \\
& =\sum_{k=1}^K \pi_k(\mathbf{x})\left\{L\sigma_k^2(\mathbf{x})+\left\|\boldsymbol{\mu}_k(\mathbf{x})-\sum_{l=1}^K \pi_l(\mathbf{x}) \boldsymbol{\mu}_l(\mathbf{x})\right\|^2\right\}
\end{aligned}
$$

証明

$$
p(\mathbf{t} \mid \mathbf{x})=\sum_{k=1}^K \pi_k(\mathbf{x}) \mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k(\mathbf{x}), \sigma_k^2(\mathbf{x}) \mathbf{I}\right)
$$

この式から計算を始める。まずは平均を\(\mathbb{E}[\mathbf{t} \mid \mathbf{x}]\)を計算すると

$$
\begin{aligned}
\mathbb{E}[\mathbf{t} \mid \mathbf{x}] & =\int \mathbf{t} p(\mathbf{t} \mid \mathbf{x}) d \mathbf{t} \\
& =\int \mathbf{t} \sum_{k=1}^K \pi_k \mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k, \sigma_k^2 \mathbf{I}\right) d \mathbf{t} \\
& =\sum_{k=1}^K \pi_k \int \mathbf{t} \mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k, \sigma_k^2 \mathbf{I}\right) d \mathbf{t} \\
& =\sum_{k=1}^K \pi_k \boldsymbol{\mu}_k
\end{aligned}
$$

となる。次に分散は\(s^2(x)=\mathbb{E}\left[\mathbf{t}^2 \mid \mathbf{x}\right]-\{\mathbb{E}[\mathbf{t} \mid \mathbf{x}]\}^2\)で求められるので、\(\mathbb{E}\left[\mathbf{t}^2 \mid \mathbf{x}\right]\)を計算すると

$$
\begin{aligned}
\mathbb{E}\left[\mathbf{t}^2 \mid \mathbf{x}\right] & =\mathbb{E}\left[\mathbf{t}^{\mathrm{T}} \mathbf{t} \mid \mathbf{x}\right] \\
& =\mathbb{E}\left[\operatorname{Tr}\left[\mathbf{t}^{\mathrm{T}} \mathbf{t}\right] \mid \mathbf{x}\right] \\
& =\mathbb{E}\left[\operatorname{Tr}\left[\mathbf{t} \mathbf{t}^{\mathrm{T}}\right] \mid \mathbf{x}\right] \\
& =\operatorname{Tr}\left[\int \mathbf{t t}^{\mathrm{T}} \sum_{k=1}^K \pi_k \mathcal{N}\left(\mathbf{t} \mid \boldsymbol{\mu}_k, \sigma_k^2 \mathbf{I}\right) d \mathbf{t}\right] \\
& =\sum_{k=1}^K \pi_k \operatorname{Tr}\left[\boldsymbol{\mu}_k \boldsymbol{\mu}_k^{\mathrm{T}}+\sigma_k^2 \mathbf{I}\right] \\
& =\sum_{k=1}^K \pi_k\left(\left\|\boldsymbol{\mu}_k\right\|^2+L \sigma_k^2\right)
\end{aligned}
$$

ここでLは\(\mathbf{t}\)の次元数である。途中の式変形では

$$
\mathbb{E}\left[\mathbf{x} \mathbf{x}^{\mathrm{T}}\right]=\boldsymbol{\mu} \boldsymbol{\mu}^{\mathrm{T}}+\boldsymbol{\Sigma}
$$

これらを用いて計算すると

$$
\begin{aligned}
s^2(\mathbf{x}) & =\sum_{k=1}^K \pi_k\left(L \sigma_k^2+\left\|\boldsymbol{\mu}_k\right\|^2\right)-\left\|\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2 \\
& =L \sum_{k=1}^K \pi_k \sigma_k^2+\sum_{k=1}^K \pi_k\left\|\boldsymbol{\mu}_k\right\|^2-\left\|\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2 \\
& =L \sum_{k=1}^K \pi_k \sigma_k^2+\sum_{k=1}^K \pi_k\left\|\boldsymbol{\mu}_k\right\|^2-2 \times\left\|\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2+1 \times\left\|\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2 \\
& =L \sum_{k=1}^K \pi_k \sigma_k^2+\sum_{k=1}^K \pi_k\left\|\boldsymbol{\mu}_k\right\|^2-2\left(\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right)\left(\sum_{k=1}^K \pi_k \boldsymbol{\mu}_k\right)+\left(\sum_{k=1}^K \pi_k\right) \| \sum_{l=1}^K \pi_{l}\mathbf{\mu}_l \|^2 \\
& =L \sum_{k=1}^K \pi_k \sigma_k^2+\sum_{k=1}^K \pi_k\left\|\boldsymbol{\mu}_k\right\|^2-2\left(\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right)\left(\sum_{k=1}^K \pi_k \boldsymbol{\mu}_k\right)+\sum_{k=1}^K \pi_k\left\|\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2 \\
& =L \sum_{k=1}^K \pi_k \sigma_k^2+\sum_{k=1}^K \pi_k\left\|\boldsymbol{\mu}_k-\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2 \\
& =\sum_{k=1}^K \pi_k\left(L \sigma_k^2+\left\|\boldsymbol{\mu}_k-\sum_{l=1}^K \pi_l \boldsymbol{\mu}_l\right\|^2\right) \\
&
\end{aligned}
$$

一つのガウス分布内での分散と一つのガウス分布の平均と全体の平均との分散を足して、混合係数をかけてそれの総和をとったものである。

この値は入力によって変化するため、最小二乗法の結果より一般的である

 

 

 

 

 

 

 

 

 

-数学, 機械学習/AI