最適輸送問題とシンクホーンアルゴリズム
最適輸送問題とは、物理学や経済学などの分野で長い歴史を持つ数理的な問題です。この問題は、ある場所に存在する物資を別の場所に運ぶ際の最適な方法を探求します。具体的には、物資を運ぶためのコストを最小化する方法を見つけることを目的としています。
-
ワッサースタイン(Wasserstein)距離とは?距離の公理を満たすことの証明まで
GANなどの生成モデルで使われるようになりたびたび聞くことも増えたワッサースタイン(Wasserstein)距離というものを解説します。 ワッサースタイン距離というのは一言でいうと、確率分布間を測る距 ...
続きを見る
この記事でも定式化していますが、最適輸送は重み付き点群を輸送コストをもとに比較するツールです。
点群Aと点群Bの違い、距離を求めるものです。これの定式化は以下です。
最適輸送の定式化
- 入力:比較する点群\(\alpha=\left\{x_1, \cdots, x_n\right\}, \beta=\left\{y_1, \cdots, y_m\right\} \subset \mathcal{X}\)、各点の距離を表す関数\(C: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\)
- 出力:点群の距離\(\mathrm{OT}(\alpha, \beta, C) \in \mathbb{R}\)
- 最適輸送距離を以下の最適化問題の最適値と定義する
$$
\begin{array}{lc}
\text{minimize}_{P \in \mathbb{R}^{n \times n}} & \sum_{i=1}^n \sum_{j=1}^n C_{i j} P_{i j} \\
\text { s.t. } & P_{i j} \geq 0 \quad \forall i\in [n], j \in [m] \\
& \sum_{j=1}^n P_{i j}=\boldsymbol{a}_i \quad \forall i \in [n] \\
& \sum_{i=1}^n P_{i j}=\boldsymbol{b}_j \quad \forall j \in [m]
\end{array}
$$
しかしこれにはいくつか短所があります。
線形計画の最適輸送の定式化の問題点
- 線形計画ソルバーの多くはブラックボックスで、計算時間の見積もりが難しく、ほかの手法と柔軟に組み合わせることができない
- 計算量が大きい、最悪入力サイズの三乗に比例する
- 入力のa,b,Cについて滑らかでない
- GPU計算ができない
中でも、問題なのは滑らかでないことです。
最適輸送コストはモデルの損失関数としても頻繁に使用され、このとき損失の勾配が計算できることが要請されます。
入力についての勾配を求めて最適化をしたい
しかし、滑らかでないのでこの定式化では微分ができず勾配を求めることができません。
でも大丈夫、ちょっと変えてあげることでなんとこれらの問題が一気に解決します。
【定義】エントロピー正則化つきの最適輸送問題
$$
\begin{aligned}
& \underset{\boldsymbol{P} \in \mathbb{R}^{n \times m}}{\operatorname{minimize}} \sum_{i=1}^n \sum_{j=1}^m \boldsymbol{C}_{i j} \boldsymbol{P}_{i j}-\varepsilon H(\boldsymbol{P}) \\
& \text { subject to } \quad P_{i j} \geq 0 \quad(\forall i \in[n], \forall j \in[m]) \\
& \sum_{j=1}^m P_{i j}=a_i \quad(\forall i \in[n]) \\
& \sum_{i=1}^n P_{i j}=b_j \quad(\forall j \in[m]) \\
&
\end{aligned}
$$
ここで、行列\(\boldsymbol{P} \in \mathbb{R}_{\geq 0}^{n \times m}\)について、エントロピー関数を
$$
H(\boldsymbol{P}) \stackrel{\text { def }}{=}-\sum_{i=1}^n \sum_{j=1}^m \boldsymbol{P}_{i j}\left(\left(\log \boldsymbol{P}_{i j}\right)-1\right)
$$
と定義します。ただし、\(0 \log 0=0\)であると定義します。
エントロピー正則化項を加えたこの問題の目的関数の第一項が\(\boldsymbol{P}\)について線形、第二項のエントロピー関数が\(\boldsymbol{P}\)について強凹関数なので、これらの差である目的関数は強凸関数になります。
また、制約条件は線形関数のみなので、実行可能領域は凸集合になります。この問題のように、目的関数が凸関数で、実行可能領域が凸集合である最適化問題を凸計画といい、凸計画に対する効率的なアルゴリズムは数多く知られています。
特に目的関数の強凸性により、より強力な最適化アルゴリズムを適用できるのが、この定式化の利点です。この最適化を解くための単純かつ高速なアルゴリズムであるシンクホーンアルゴリズムを紹介します。このシンクホーンアルゴリズムはGPUを用いた並列化を行うことができるほか、エントロピー正則化を加えた問題においては最適値\(\mathrm{OT}_{\varepsilon}\)が各種パラメータについて微分可能となります。
シンクホーンアルゴリズムはエントロピー正則化つき問題を効率よく解くアルゴリズムです。
双対問題の導出
まず,エントロピー正則化つき問題の目的関数にラグランジュ乗数 \(f \in\) \(\mathbb{R}^n, g \in \mathbb{R}^m\) を導入して,以下のラグランジュ緩和問題を考えます。すなわち
$$
\begin{aligned}
L(\boldsymbol{P}, \boldsymbol{f}, \boldsymbol{g}) \stackrel{\text { def }}{=}\left(\sum_{i=1}^n \sum_{j=1}^m \boldsymbol{P}_{i j} \boldsymbol{C}_{i j}\right)-\varepsilon H(\boldsymbol{P}) \\
\quad+\sum_{i=1}^n \boldsymbol{f}_i\left(\boldsymbol{a}_i-\sum_{j=1}^m \boldsymbol{P}_{i j}\right)+\sum_{j=1}^m \boldsymbol{g}_j\left(\boldsymbol{b}_j-\sum_{i=1}^n \boldsymbol{P}_{i j}\right)
\end{aligned}
$$
とし、
$$
\underset{\boldsymbol{P} \in \mathbb{R}_{\geq 0}^{n \times m}}{\operatorname{minimize}} L(\boldsymbol{P}, \boldsymbol{f}, \boldsymbol{g})
$$
を考えます。ラグランジュ関数\(L\)の\(\boldsymbol{P}\)についての偏微分は
$$
\frac{\partial L}{\partial \boldsymbol{P}_{i j}}=\boldsymbol{C}_{i j}+\boldsymbol{\varepsilon} \log \boldsymbol{P}_{i j}-\boldsymbol{f}_{\boldsymbol{i}}-\boldsymbol{g}_j
$$
です。これをゼロとおくと、緩和問題の解は
緩和問題の解
$$
\boldsymbol{P}_{i j}^*=\exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
と求まります。このとき目的関数値は
$$
\begin{aligned}
& L\left(\boldsymbol{P}^*, \boldsymbol{f}, \boldsymbol{g}\right) \\
& =\sum_{i=1}^n \sum_{j=1}^m \boldsymbol{P}_{i j}^*\left(\boldsymbol{C}_{i j}-\boldsymbol{f}_i-\boldsymbol{g}_j+\varepsilon\left(\log \boldsymbol{P}_{i j}^*-1\right)\right)+\sum_{i=1}^n \boldsymbol{f}_i \boldsymbol{a}_i+\sum_{j=1}^m \boldsymbol{g}_j \boldsymbol{b}_j \\
& =\sum_{i=1}^n \boldsymbol{f}_i \boldsymbol{a}_i+\sum_{i=1}^m g_j \boldsymbol{b}_j-\varepsilon \sum_{i=1}^n \sum_{i=1}^m \exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
\end{aligned}
$$
となります。ただし、最後の等式には
$$
\frac{\partial L}{\partial \boldsymbol{P}_{i j}}=\boldsymbol{C}_{i j}+\varepsilon \log \boldsymbol{P}_{i j}-\boldsymbol{f}_{\boldsymbol{i}}-\boldsymbol{g}_j=0
$$
であることと\(\boldsymbol{P}_{i j}^*=\exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)\)を用いました。
よって、ラグランジュ双対問題は
$$
\underset{\boldsymbol{f} \in \mathbb{R}^n, \boldsymbol{g} \in \mathbb{R}^m}{\operatorname{maximize}} \sum_{i=1}^n \boldsymbol{f}_i a_i+\sum_{j=1}^m \boldsymbol{g}_j b_j-\varepsilon \sum_{i=1}^n \sum_{j=1}^m \exp \left(\left(f_i+g_j-C_{i j}\right) / \varepsilon\right)
$$
となります。一般に、凸計画問題においてすべての不等号が厳密に成り立つ実行可能解が存在するときスレイターの条件が成り立つといい、このとき主問題と双対問題の最適値が一致する強双対性が成り立つことが知られています。
エントロピー正則化付きの最適輸送問題は凸計画であって、スレイター条件を満たすので、強双対性が成り立ちます。
すなわち、エントロピー正則化付き最適輸送問題とラングランジュ双対問題の最適値は一致します。ラグランジュ双対問題が解けると直ちに主問題の最適解も定まるため、以降しばらくは双対問題を解くことに集中します。
双対問題は制約なし最大化問題であって、目的関数が狭義凹なので、最適化の観点からは扱いやすい問題になります。
対数領域シンクホーンアルゴリズム
双対問題の目的関数を
$$
L_D(\boldsymbol{f}, \boldsymbol{g}) \stackrel{\text { def }}{=} \sum_{i=1}^n \boldsymbol{f}_i \boldsymbol{a}_i+\sum_{j=1}^m \boldsymbol{g}_j \boldsymbol{b}_{\boldsymbol{j}}-\varepsilon \sum_{i=1}^n \sum_{j=1}^m \exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
とおくと、\(L_D\)の\(\boldsymbol{f}, \boldsymbol{g}\) についての勾配は
$$
\begin{aligned}
& \frac{\partial L_D}{\partial \boldsymbol{f}_i}=\boldsymbol{a}_i-\sum_{j=1}^m \exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right) \\
& \frac{\partial L_D}{\partial \boldsymbol{g}_j}=\boldsymbol{b}_j-\sum_{i=1}^n \exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
\end{aligned}
$$
となります。注目すべきは、\(\boldsymbol{f}\)についての勾配についての式には\(\boldsymbol{f}_j(j \neq i)\)が出現しないことです。すなわち\(\boldsymbol{g}\)を固定すると、最適な\(\boldsymbol{f}\)は成分ごと独立に計算でき、その時の値は
$$
\frac{\partial L_D}{\partial \boldsymbol{f}_i}=\boldsymbol{a}_i-\sum_{j=1}^m \exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)=0
$$
と置くと、
$$
\boldsymbol{f}^*(\boldsymbol{g})_i=\varepsilon \log \boldsymbol{a}_i-\varepsilon \log \sum_{j=1}^m \exp \left(\left(\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
となります。\(\boldsymbol{g}\)についても同様に、\(\boldsymbol{f}\)を固定したときの最適な\(\boldsymbol{g}\)は
$$
\boldsymbol{g}^*(\boldsymbol{f})_j=\varepsilon \log \boldsymbol{b}_j-\varepsilon \log \sum_{i=1}^n \exp \left(\left(\boldsymbol{f}_i-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
となります。この式より、\(\boldsymbol{f}\)を固定して\(\boldsymbol{g}\)を最大化し、得られた\(\boldsymbol{g}\)を固定して\(\boldsymbol{f}\)を最大化することを繰り返すブロック座標上昇法が自然に導けます。これが対数領域シンクホーンアルゴリズムです。疑似コードは以下です
ここで、\(C_{i,:}, C_{:, j}\)はそれぞれ、行列\(C\)の\(i\)行目と\(j\)列目を表すベクトルを示し、ベクトルに対する\(\log\) 関数と \(\exp\) 関数は成分ごとに適用することにします。
またここでは、確率ベクトルの成分はすべて正であると暗黙的に仮定しています。成分が0である場合には、\(log0\)の項が登場してしまうため、場合分けを行うか微少量を加えて対処することになります。
対数領域シンクホーンアルゴリズムにより双対解が得られた後、主問題の解(輸送行列)を得たければ、式
$$
\boldsymbol{P}_{i j}^*=\exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
を用いて変換することになります。強双対性より、\((\boldsymbol{f},\boldsymbol{g})\)が最適解であれば得られる輸送行列も主問題の最適解になります。ただし、\((\boldsymbol{f},\boldsymbol{g})\)が最適解でなければ、先の式で得られる\(P\)が実行可能解である保証すらありません。最適でない双対解についての実行可能な主解を得る方法についてはのちに議論します。
シンクホーンアルゴリズムの導出
ここまで対数領域シンクホーンアルゴリズムを導出しました。ここからは、変数変換により、さらにシンプルなアルゴリズムを導出します。通常シンクホーンアルゴリズムというと本節で導出する指数領域でのシンクホーンアルゴリズムのことを指します。
$$
\boldsymbol{f}^*(\boldsymbol{g})_i=\varepsilon \log \boldsymbol{a}_i-\varepsilon \log \sum_{j=1}^m \exp \left(\left(\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
$$
\boldsymbol{g}^*(\boldsymbol{f})_j=\varepsilon \log b_j-\varepsilon \log \sum_{i=1}^n \exp \left(\left(\boldsymbol{f}_i-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
にはlogやexp関数が多く登場します。そこで、決定変数を\(\boldsymbol{u} \stackrel{\text { def }}{=} \exp (\boldsymbol{f} / \varepsilon), \boldsymbol{v} \stackrel{\text { def }}{=} \exp (\boldsymbol{g} / \varepsilon)\)と変数変換することで、これらの式は
$$
\begin{aligned}
\boldsymbol{u}^*(\boldsymbol{v})_i & =\frac{a_i}{\sum_{j=1}^m \boldsymbol{v}_j \exp \left(-\boldsymbol{C}_{i j} / \varepsilon\right)} \\
\boldsymbol{v}^*(\boldsymbol{u})_j & =\frac{b_j}{\sum_{i=1}^n \boldsymbol{u}_i \exp \left(-\boldsymbol{C}_{i j} / \varepsilon\right)}
\end{aligned}
$$
とより簡潔に表せます。また、ギブスカーネル\(\boldsymbol{K}_{i j}=\exp \left(-\boldsymbol{C}_{i j} / \varepsilon\right)\)を用いると、上式はさらに簡潔に
$$
\begin{aligned}
\boldsymbol{u}^*(\boldsymbol{v}) & =\frac{\boldsymbol{a}}{\boldsymbol{K} \boldsymbol{v}} \\
\boldsymbol{v}^*(\boldsymbol{u}) & =\frac{\boldsymbol{b}}{\boldsymbol{K}^{\top} \boldsymbol{u}}
\end{aligned}
$$
と表すことが出来ます。ただし、ベクトルの割り算は成分ごとの割り算とします。この二式を用いた交互最適化がシンクホーンアルゴリズムです。このアルゴリズムはCuturiにより機械学習コミュニティに導入され、活発に利用・研究されています。
初期ベクトルを\(\frac{1}{\|\boldsymbol{K}\|_1} \mathbb{1}_m\)としているのはのちの理論解析のためであり、実用上は \(\boldsymbol{v}^{(0)} \leftarrow \mathbb{1}_m\) などと適当に初期化することも多いです。 \(\boldsymbol{f}\) と \(\boldsymbol{u}\) は 一対一に対応し, \(\boldsymbol{g}\) と \(\boldsymbol{v}\) は一対一に対応するのでシンクホーンアルゴリ ズムは対数領域版と動作は同じとなります。 また, 指数領域において最適解 \(\left(\boldsymbol{u}^*, \boldsymbol{v}^*\right)\) が求まると, 変数変換を逆に行い, 対数領域における最適解 \(\left(\boldsymbol{f}^*, \boldsymbol{g}^*\right)=\left(\varepsilon \log \boldsymbol{u}^*, \varepsilon \log \boldsymbol{v}^*\right)\) を得ることができます。
緩和問題の解よりここからさらに主問題の最適解も得ることができます。ただし,完全に収束するまでは変換によって必ずしも実行可能解が得られるとは限らない点は対数領域の場合と同じく注意が必要です。
シンクホーンアルゴリズムの強みの一つが, 非常にシンプルなことです。 アルゴリズムが行列積と割り算だけから構成されるので,他の手法に組み込 むことも容易です. 1 回の最適化反復が \(O(n m)\) 時間と高速なことも強みの一つです.また,GPUによる並列計算も容易です。
指数領域の場合の注意点としては、数学的には対数領域版と動作がまったく同じですが,値が指数的に大きく・小さくなる場合があり,そのような場合には計算機上では数値的に不安定になる場合があることです。 そのような場合, 正則化係数を大きくとるなどの対処が必要となります。 安定な大きさ の \(\varepsilon\) における結果に満足がいかない場合は, \(\varepsilon\) を小さくして対数領域でのシンクホーンアルゴリズムを数値安定な log sum exp 関数とともに用いるか, \(\varepsilon\) の大きさはそのままに 3.7 節で述べるシンクホーンダイバージェンスを用いるとよいでしょう.
シンクホーンアルゴリズムにより得た近似解を主問題の解に変換する
シンクホーンアルゴリズムにより、双対問題の解と目的問題の値を得ることができますが、主問題の解、すなわち輸送行列を得たい場合もしばしば存在あります。無限回の反復のあと、双対解は厳密解となり、
$$
\boldsymbol{P}_{i j}^*=\exp \left(\left(\boldsymbol{f}_i+\boldsymbol{g}_j-\boldsymbol{C}_{i j}\right) / \varepsilon\right)
$$
を用いて主問題の解を得ることができますが、現実的には有限回で打ちとめることになるため、そのままでは主問題の解を得ることはできません。本節では、厳密ではない双対解をさきの式により変換した後に、輸送行列にうまく丸める方法を紹介します。疑似コードを示します。
直感的には1~4行目において、\(\boldsymbol{A}\)の行和・列和と\(\boldsymbol{a},\boldsymbol{b}\)の比を各行・列とかけ合わせることで補正します。ただし、比が大きい箇所については誤差の増大を防ぐためにそのままとします。つづいて、5~7行目において、残った違反分を加法的に修正します。このアルゴリズムの正当性は以下の定理により示されます。
【定理】解の丸めによる誤差
任意の正行列を予想行列に丸めるアルゴリズムの出力\(\boldsymbol{P}\)は輸送多面体\(\mathcal{U}(a, b)\)に含まれ、入力行列との誤差は
$$
\|\boldsymbol{P}-\boldsymbol{A}\|_1 \leq 2\left(\left\|\boldsymbol{A} \mathbb{1}_m-\boldsymbol{a}\right\|_1+\left\|\boldsymbol{A}^{\top} \mathbb{1}_n-\boldsymbol{b}\right\|_1\right)
$$
で抑えられる
ゆえに実行可能とは限らない主問題の解\(\boldsymbol{A}\)があったとき、\(\boldsymbol{A}\)の目的関数値が小さく、かつ\(\boldsymbol{A}\)の行和・列和が\(\boldsymbol{a},\boldsymbol{b}\)に十分近ければ、アルゴリズムによって実行可能かつ目的関数値が小さい解を計算できます。
参考
https://speakerdeck.com/joisino