数学 時系列分析 線形代数 自然言語処理

Multi-Head AttentionとScaled Dot-Product Attentionの全て:Transformerの核心を徹底解説

2023年12月9日

この記事では、Transformerの中心的な役割を果たすMulti-Head Attentionについて解説する。

 

Transformerのほかの機構の詳細な解説はせず、完全にMulti-Head Attention特化の解説となるので、ほかの機構や全体観を掴みたい方はこちらの記事をご一読いただきたい。

Transformerとは?世界を変えた深層学習モデルの仕組みをわかりやすく徹底解説

近年のAI技術の急激な発展には「Transformer」という深層学習モデルの存在が大きく関わっている。 この記事では、そのTrasformerについてその仕組みとそれがなぜ組み込まれているかを画像を ...

続きを見る

この記事はまずはScaled Dot-Product AttentionというMulti-Head Attentionの中で使われている核心部分についてこれでもかと詳しく解説したのちに、本題のMulti-Head Attentionについて解説し、その後Transformerのデコーダー部分で使われる二つの注意機構について解説する。

これ以上ないくらい丁寧に解説をしたつもりである。是非ゆるりと読んでみてほしい、Multi-Head Attentionを完璧と言って差し支えないレベルで理解できると思う。なお、この記事では直感的なわかりやすさを重視し画像をふんだんに使用している。要するに渾身極まる記事である。

Transformerの解説記事とこの記事を併せて読めば、Transformerのすべてが理解できると思う。

 

必要な知識は高校数学まで。(逆に昨今AI界を席巻するTransformerの核心部分が高校数学まででできているというのは面白い)

 

Transformerとは

Attentionについて説明する前に、軽くTransformerについて触れておこう。

Transformerとは、自然言語処理(Natural Language Processing, NLP)の分野において、2017年に「Attention Is All You Need」(Vaswani et al, 2017)という論文で初めて提案されたモデルである。従来のリカレントニューラルネットワーク(Recurrent Neural Network, RNN)やその派生形であるLSTM(Long Short-Term Memory)が持つ時系列データを順番に処理する特性とは異なり、トランスフォーマーは「Attention機構」を用いることで、入力されたデータの全ての部分を一度に処理することが可能である。これにより、文中の遠距離にある単語間の関連性を捉えたり、計算速度を向上させるなどの利点を持つ。

Transformerのコアとなるのが「Scaled Dot-Product Attention機構」であり、これは入力されたシーケンス内の各位置の単語が、他の位置の単語にどれだけ影響を受けるか、またはどれだけの関連があるかを計算する仕組みである。このScale Dot-Product Attentionを拡張したものが「Multi-Head Attention」であり、異なる表現空間でAttention機構を複数回行い、異なる視点から情報を集約することで、より豊かな文脈表現を可能にしている。

このモデルは、翻訳や文章要約、質問応答システムなど、多くのNLPタスクにおいて顕著な成果を上げており、現在では多くの最先端技術の基盤となっている。また、トランスフォーマーはNLPに限らず、画像処理音声認識など、他の機械学習の分野においても応用されつつある。それにより、トランスフォーマーは機械学習モデルの新たなスタンダードとして広く認知されているのである。

詳しくはこちらの記事を読んでほしい。

Transformerとは?世界を変えた深層学習モデルの仕組みをわかりやすく徹底解説

近年のAI技術の急激な発展には「Transformer」という深層学習モデルの存在が大きく関わっている。 この記事では、そのTrasformerについてその仕組みとそれがなぜ組み込まれているかを画像を ...

続きを見る

 

Multi-Head Attentionとは

Multi-Head Attentionは入力の直前に線形変換をする層を持つ「Scaled Dot-Product Attention」を複数並列に並べた配置をしていて、それを連結した行列を線形変換したものを出力としている。

このように、Attention機構を複数並列に行うことからMulti-Headと呼ばれ、これが多様な表現を可能にしているのである。

複雑な構造をしているが、その根本はScaled Dot-Product Attentionであり、それさえ理解できてしまえばMulti-Head Attentionを理解するのは容易なのである。

 

Scaled Dot-Product Attentionを理解しよう

Scaled Dot-Product Attention(スケール化内積注意)は、内積を活用してベクトル間の類似性に基づく変換を行う仕組みである。この機構は、Query(クエリ)、Key(キー)、Value(バリュー)という3つの行列の入力を用いる。このQ,K,Vはトークンの埋め込み列から得たベクトルの列である。

この機構は、埋め込み列を入力として受け取り、それらを相互に参照して、文脈の情報を加えた新しい埋め込み列を計算するのが目的。

さらに、このScaled Dot-Product Attentionの押さえておくべき重要な事項は入力の行列のサイズと出力の行列のサイズが一致しているということである。(厳密に言うと入力のVの行列のサイズと出力の行列のサイズが一致)

Scaled Dot-Product Attentionの計算

Scaled Dot-Product Attentionでは次の計算を行う。

$$
softmax\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{\mathrm{T}}}{\sqrt{\mathrm{d}_{\mathrm{k}}}}\right) \boldsymbol{V}
$$

Scaled-Dot Product Attentionの模式図

Scaled Dot-Product Attentionの基本的な構造は上図の通りである。いろいろやっているように見えるが結局\(softmax\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{\mathrm{T}}}{\sqrt{\mathrm{d}_{\mathrm{k}}}}\right) \boldsymbol{V}\)を計算するだけである。\(\mathrm{d}_{\mathrm{k}}\)はクエリとキーのベクトルの次元

上の図はこの図を段階的に表したに過ぎない。

Q,K,Vの作り方

そもそもQ,K,Vはどこから来たのか、まずはそこを確認しておこう。

 

位置符号を加えた「\(\text{シーケンス長}n\times d_{model} \)」の埋め込み行列\(X\)に対して、三つの\(d_{model} \times d_{model}\)次元の専用の変換行列を使って線形変換を行う。これらの線形変換行列は、トランスフォーマーモデルの学習パラメータの一部として扱われ、学習過程で最適化される。

計算方法は、埋め込み行列に対して、右からそれぞれの変換行列を掛けてあげることでQ,K,Vが得られる。

$$
\begin{aligned}
& Q=X W_Q \\
& K=X W_K \\
& V=X W_V
\end{aligned}
$$

なぜ、そもそもQ,K,Vに分解したのかはここからの計算の意味を理解するとわかるようになる。

 

まずは簡易な例で流れを把握しよう

 

イメージしやすくするための具体例として「マウスでクリックする」という文章を考えてみる。このような場合「マウス」の意味がパソコンで使用する入力機器のマウスの方でネズミの方じゃないということを捉えるには、ほかのトークン情報よりも「クリック」という単語の方が重要かもしれない。実際我々人間が判断するときもそうである。

さて、「マウスでクリック」という文章がトークナイザー(単語以下に分割するもの)によって「マウス/で/クリック」という風にトークンに分割されたとする。

そしてそのトークン列の位置符号を加算し終えた埋め込み列\(X\)が次のような行列になったとする。ここでは、簡単のため各トークンのベクトルの次元は4とする。つまり\(d_{model}=4\)で行列のサイズは\(3 \times 4\)となる

$$
X=\left(\begin{array}{llll}
1 & 0 & 1 & 2 \\
2 & 1 & 2 & 0 \\
0 & 0 & 1 & 1
\end{array}\right)
$$

これに変換行列を掛けて、Q,K,Vを得たとする。

 

行列積

$$
Q K^{\top}
$$

まずは、\(Q\)と転置した\(K\)の行列を掛け合わせる。これは(シーケンス長\(n\)×\(d_k\))×(\(d_k\)×シーケンス長\(n\))という行列積の形なので、出力される行列のサイズは(シーケンス長\(n\)×シーケンス長\(n\))である。

転置とは、行と列を入れ替える操作のことである。

スケール

$$
\frac{Q K^{\top}}{\sqrt{d_k}}
$$

行列のすべての要素を\(\sqrt{d_k}\)で割る。今回の場合\(d_k=4\)なので、すべての要素を2で割る。

Softmax関数

 

$$
softmax(\frac{Q K^{\top}}{\sqrt{d_k}})
$$

そして得た行列に対して、softmax関数を適用する。Softmax関数はベクトルを入力として、その要素を0から1の間に変換する関数である。

【定義】Softmax関数

\(n\)次元の実数ベクトル\(z = z_1,z_2, \cdots ,z_n\)に対し

$$
Softmax\left(z_i\right)=\frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}}
$$

ソフトマックス(Softmax)関数とは:定義、特性、実装方法まで徹底解説

本記事は、Softmax関数に関する理解を深めたい方々を対象にしています。 必要な前提知識は以下の通りです。 基本的な数学知識(特に確率論と指数関数) ニューラルネットワークの基本概念 Pythonプ ...

続きを見る

 

この操作で、行列の横方向の総和が1に調整される。

softmax関数を適用する方向

softmax関数の引数\(\frac{Q K^{\top}}{\sqrt{d_k}}\)は行列なので、softmaxを縦方向、横方向どちらに適用するか迷うところだが、Transformerにおいては横方向に適用する。この理由は後程自明となる。

行列積

$$
\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

そして最後に、Valueベクトルを並べた行列Vを掛けて、このScaled Dot-Product Attentionの計算はすべて完了である。

初心者でも理解できる!行列とベクトルの積の基礎と実践的応用

数学の世界において、行列とベクトルの積は基本的かつ強力な概念である。 この記事では、行列とベクトルの積がどのように計算され、どのような幾何学的意味を持つのかを段階的に解説する。 さらに、コンピュータグ ...

続きを見る

計算をまとめると

 

Scaled Dot-Product Attentionの計算処理が何をしているのか自体はわかったことと思う。特筆すべきは、入力の行列のサイズと出力の行列のサイズが一致していることである。Transforemerのブロックを縦に並べる性質上当たり前ではあるが、こうしないとAttention機構を通るたびに行列のサイズが変化してしまい、次の入力の次元が変わりまくってしまうので非常に扱いづらくなってしまう

 

さて、Scaled Dot-Product Attentionがどういう計算をしているのかは分かった。だが、先ほどから述べているようにこの機構は文脈の情報、つまりその文章におけるその単語同士の関係性を測ることができる。それは一体どのように行われていて、なぜ可能なのであろうか。

 

ここからは段階的に今の計算を振り返って一つ一つその秘密を紐解き、それぞれの計算をする理由、お気持ちを理解しよう。

Scaled Dot-Product Attentionの気持ちを理解する

トークン埋め込み行列の横方向のベクトルは順に入力されたトークン(単語)に対応する横ベクトルとなる。

それを分割して並べ、視覚的にそれぞれがどういう関係で計算フローが成り立っているのか見ることでこのScaled Dot-Product Attentionは理解しやすくなる。

さて、一つ一つ順に考えていく。まずはScaled Dot-Product Attentionに入力するQ,K,Vを得たいので

学習可能なパラメータ行列を使用した線形変換をすることでKeyベクトルとValueベクトルを作成する。

順番に処理を追っていきたいので、まずは一つ目の「マウス」のトークンから線形変換によって得たQueryベクトルを見る。

さて、Scaled Dot-Product Attentionの最初の工程、\(Q K^{\top}\)が行われている。これは最初の一行一列目の結果である。

これは一つ目のトークン「マウス」のKeyベクトルとQueryベクトルの内積を計算しているに等しい。内積の計算、すなわちその二つのベクトルの類似度を計算しているのである。

 

Attentionではなぜコサイン類似度ではなく内積を計算するのか

なぜコサイン類似度ではなく内積を計算するの?

ここである二つのベクトルが似ているとはどういうことか考えてみよう。

ベクトルの成分の大小が一致していること?、似たところがプラスで似たところがマイナスになっていることか?

これはまさに、ある二つのベクトルの向きが近いということに他ならない。つまり、ベクトルの向きが近いということは二つのベクトルのなす角\(\theta\)の値が小さいということである。

ここで内積の定義を思い出してみると

【定義】内積

$$
\begin{aligned}
& \boldsymbol{x}=\left(\begin{array}{c}
x_1 \\
\vdots \\
x_n
\end{array}\right), \boldsymbol{y}=\left(\begin{array}{c}
y_1 \\
\vdots \\
y_n
\end{array}\right) \in \mathbb{R}^n \\
& \boldsymbol{x} \cdot \boldsymbol{y}=\|\boldsymbol{x}\|\|\boldsymbol{y}\| \cos \theta
\end{aligned}
$$

つまり、\(\theta\)が小さいということは\(cos\theta\)が大きくなるのですなわち内積が大きいということになる。

よって内積が大きさがそのベクトルの類似度ということになる。

こういった自然言語処理の文脈でよく用いられるコサイン類似度の定義は、以下のようなものである。

【定義】コサイン類似度

$$
\begin{aligned}
& \boldsymbol{x}=\left(\begin{array}{c}
x_1 \\
\vdots \\
x_n
\end{array}\right), \boldsymbol{y}=\left(\begin{array}{c}
y_1 \\
\vdots \\
y_n
\end{array}\right) \in \mathbb{R}^n \\
& \cos (\boldsymbol{x}, \boldsymbol{y})=\frac{\langle\boldsymbol{x}, \boldsymbol{y}\rangle}{|\boldsymbol{x} \| \boldsymbol{y}|}
\end{aligned}
$$

コサイン類似度の大きな特徴として、ベクトルの方向のみを考慮し、大きさの影響を排除していることがあるが、これがよくない。

なぜならば、Transformerのエンコーダーブロックは縦に\(N\)回も連続して、そのたびに文脈の情報を付加している。

例えばベクトルが大きければ大きいほど、その単語が持つ情報量や意味の強度が大きいと解釈されることがある。さらには、その単語がほかの単語に比べて特に重要である場合、その単語を表すベクトルは相対的に大きくなることがある。

コサイン類似度はその情報を排除してしまうことになるので、ここでは様々な情報を含め「直接的に」類似しているかどうかを測れる内積が使われるのである。

コサイン類似度とは?ベクトルの内積から見る類似度

データ分析や機械学習の分野で広く用いられるコサイン類似度は、ベクトル間の類似性を測定する強力なツールである。   この記事では、コサイン類似度の基本原理から、高次元データでの振る舞い、そして ...

続きを見る

同様の計算を「マウス」のQueryベクトルに対して「で」と「クリック」のKeyベクトルとの内積を計算する。

これによって「マウス」のQueryベクトルの処理の際に、どのトークン(単語)がそれと類似しているのかどうか、つまり関連しているのかどうかを計算しているのだ。ここでは「マウス」と「クリック」の内積の値が相対的に大きくなっている。「マウス」が生物なのかPC周辺機器なのかは「クリック」という単語から判断でき関連していると考えることができるからである。(ここではそうなるように僕が調整した値だが、イメージ的にはそういうこと)

続いての工程、スケーリングの部分である。ここでは、内積による類似度を\(\sqrt{d_k}=2\)で割ることでスケールの調整を行う。

 

Attentionではなぜ次元数の平方根で割るのか

なんで次元数の平方根で割るの??

\(d_k\)次元のベクトル\(Q,K\)が平均0、分散1(\(\mu=0,\sigma^2=1\))の独立した確率変数であると考えてみると

その内積は\( Q K^{\top}=\sum_{m=1}^{d_k} q_m k_m \)で、平均0、分散\(d_k\)になる。(分散の線形性)

つまり次元が増えれば増えるほど、分散が大きくなってしまう。

ここで使うべきなのが、確率変数の標準化である。

確率変数の標準化

確率変数\(X\)の平均\(\mathrm{E}[X]=\mu\)、分散\(Var(X)=E\left[(X-\mu)^2\right]=\sigma^2\)が存在するとき

$$
Z=(X-\mu) / \sigma
$$

と置くと、\(\mathrm{E}[Z]=0, \operatorname{Var}(Z)=1\)となる。これを確率変数の標準化と呼ぶ。

\(\mu=0\)なので、あとは分散\(d_k\)の平方根で割れば、標準化を達成でき分散が1になるので、

もとの\(Q,K\)の平均、分散と同じになる。

そして、これは次に通すSoftmax関数のグラフである。

三つのパラメータを引数に持つSoftmax関数のグラフ。x_1を変動させたときのsoftmax関数の値を計算したもの。x_2=1,x_3=2と固定し、x_1は-10から10までの範囲で0.1刻みで変動させそれぞれをプロットしたもの。

x_2=1,x_3=2という風に少し大きめに0からずらしてしまっているが

この赤枠の中ではかなり大きく三つのプロットが変化しているが、青枠の中では変化の勾配が小さくなってしまっている。これによって学習が進まなくなってしまう現象が起こってしまう。これを勾配消失という。

分散が大きいということは青枠の中に入りがちになってしまうということである。

なのでこれを防ぐべく赤枠の中に何とか収めたいという気持ちの標準化である。

続いて、スケーリングしたのちに、さらにSoftmax関数を使用してそれぞれの値を0から1の間にして、その総和が1になるように値を調整する。

スケーリング、Softmax関数この二つの処理によって、内積だった値がさらに類似度としての意味合いを強めていく。

内積の値が仮に大きめの負の値であったとすると、Softmax関数が指数関数を通す性質上、0に近い正の値に整えられる。つまり、単語間の関連度が低い場合その値は0に近くなる。逆に関連度が高い場合は1に近くなり、直感的にも分かりやすい。

 

つづいて、その値にそれぞれのValueベクトルの値を掛け合わせる。

 

そして、得られたベクトルを足し合わせる。これが一つ目「マウス」のトークンの出力値となる。

この出力値は、三つ目のトークン「クリック」Valueベクトルの情報をたくさん保持していると考えることができる。

これが、最終的にScaled Dot-Product Attentionから出力される行列の一行目の出力が完了した。つまり、これが一つ目のトークン「マウス」に対応するベクトル情報である。

 

なので、つづいて二つ目のトークンについての計算を行わなければならない。

一つ目のトークンの時と同様、線形変換によって得たQuery行列の二行目のQueryベクトルを使用する。

そして、先ほどと全く同じ計算プロセスで二つ目のトークン「で」の出力値を算出する。

 

最後に三つ目のトークン「クリック」に関する計算を行う。これまでと全く同様、線形変換によって得たQuery行列の三行目のQueryベクトルを使用して

同じ計算プロセスで「クリック」の出力値を計算する。

そして最後にこれまで出力した横ベクトルを縦につなぎ合わせると

 

これによって先ほど得たScaled Dot-Product Attentionの出力を得ることができた。

Attention Matrix

このScaled Dot-Product Attentionの過程において最も重要なのは単語間の関連度の測定である。ここまで読んでくれた方ならもうわかると思うが、それを担っているのはSoftmax関数と通した直後のシーケンス長×シーケンス長の正方行列である。

この正方行列がそのまま単語間の関連度を表している。この行列をAttention Matrixと呼ぶことがある。

 

アテンションマップは、ニューラルネットワークモデルが、入力された文中の各トークンがどの程度他のトークンに影響を受けているか、またはどの程度他のトークンに影響を与えているかを視覚的に表現した行列である。具体的には、あるトークンの行と別のトークンの列が交差する位置にある数値が、その二つのトークン間のアテンションの重みを示している。

「マウス」の行と「クリック」の列にある要素は、「マウス」というトークンが「クリック」というトークンにどの程度影響を受けているか(またはその逆)を示している。一方、「クリック」の行と「マウス」の列にある要素は、「クリック」というトークンが「マウス」というトークンにどの程度影響を受けているか(またはその逆)を示している。

例えば、あるトークン「A」の行と別のトークン「B」の列が交差する位置にある数値が高ければ、トークン「A」がトークン「B」から強いアテンション(注意)を受けていることを意味する。

 

アテンションマップを解釈する際は、以下のポイントに注意すると良いだろう。

  • 数値が高いほど、その二つのトークン間の関連性が強いとモデルが判断していることを示している。
  • 一つのトークンに対して全てのトークンのアテンションの重みを横方向に合計すると、通常は1になる(あるいはそれに近い値)。これはモデルがトークン間の関連性を正規化しているためである。
  • アテンションマップは対称ではないことが多く、つまり「トークンAから見たトークンBの重要性」と「トークンBから見たトークンAの重要性」は必ずしも一致しない。

実際にAttention Matrixを計算してみた。

Q,K,Vの意味を改めて考える

ここまで、理解していただけたのであればQ,K,Vそれぞれの役割は分かったと思う。改めてここで整理しておこう。

 

  • Query:その名の通り「問い合わせ」で、関連の強い別のトークン(自分自身も含む)を探すためのもの
  • Key:自分自身や別のトークンのQueryからの問い合わせを受ける「索引」である。
  • Value:そのトークンから得られる「」。これがその類似度に応じて伝播されていく。

なんだか伏線回収みたいで熱くなるものがある気がする。

 

Scaled Dot-Product Attentionの長所・短所

長所

  1. シンプルな構造: Scaled Dot-Product Attentionは内積とスケーリング、softmax関数を基本とするため、構造が非常にシンプルである。このシンプルさにより、実装が容易で、理解しやすい。
  2. 効率的な計算: この機構は、QueryとKeyの間の類似度を一括で計算できるため、特に大規模なデータセットにおいて計算が効率的である。並列処理が容易なため、GPUなどのハードウェアをフル活用できる。
  3. 文脈の捉え方: Queryと文中の各単語(Key)との類似度に基づくため、文脈に応じて単語の重要度が動的に変化する。これにより、文の意味をより豊かに理解することができる。

短所

  1. 学習パラメータの不在: Scaled Dot-Product Attentionには学習パラメータが存在しないため、モデルが文脈の細かいニュアンスを学習するのに限界がある。特に、多様な特徴部分空間における注意表現を捉えることができない。つまり、注意を一つの意味にしか払えない。
  2. 固定された注意機構: この機構は、固定されたアルゴリズムに基づいているため、異なるタイプのデータや特定のタスクに対して最適化することが難しい。例えば、特定の文脈や言語特有の特性に柔軟に適応することが困難である。
  3. 長いシーケンスへの対応: 長いシーケンスに対しては、全ての単語間の類似度を計算する必要があるため、計算量が大きくなる。特に、メモリ使用量が問題になることがある。これは内積の計算のところで計算量がシーケンス長\(n\)の二乗に比例してしまう。

我々が普段文章を読むとき、文章を様々な観点から理解している

日本語を英語に翻訳するとき、人称は?時制は?場所は?クリックして何をしたの?肯定文なのか疑問文なのか命令文なのか?などなど様々な要素を把握しながら変換する。

これを今回の機械翻訳でも行いたいわけだが、Scaled Dot-Product Attentionでは内積を行う性質上、注意を一つの意味にしか払えない。

 

Scaled Dot-Product Attentionは、そのシンプルさと効率性により広く用いられているが、柔軟性と複雑な文脈の理解には制限がある。これらの短所を克服するために、Multi-Head Attentionなどの改良された注意機構が開発され延いてはTransformerにつながっていくのである。

三つ目に関しては後続の研究で様々な対応策が出てきている。

Multi-Head Attentionを理解しよう

Multi-Head Attentionは入力の「シーケンス長\(n \times \)トークン表現次元\(d_{model}\)」のサイズで構成される行列Q(Query),K(Key),V(Value)を分割し,h個の並列したヘッドで異なるScaled Dot-Product Attention処理を行った結果を、最後に一つに結合する処理である。

この分割したそれぞれをHead(ヘッド)と呼び、それぞれで線形変換と先ほどのScaled Dot-Product Attentionによる計算が行われる。

ポイントは線形変換する際の\(\mathbf{W}_i^Q,\mathbf{W}_i^K,\mathbf{W}_i^V\)の線形変換行列がそれぞれのヘッドですべて違うことと、そのサイズが\(d_{model}\times d_{model}/h\)である。

ここからは元論文の分割するヘッド数\(h\)は\(h=8\)となっているのでそのように考えて流れを追ってみる。つまり8つのヘッドに分けて計算を行う。

Q,K,V低次元射影変換

ここからはある\(i\)番目のヘッドについての動きを見てみよう。

 

Multi-head Attentionの入力であるQ,K,Vのサイズ「シーケンス長\(n \times d_{model}\)」を「シーケンス長\(n \times d_k=d_v=d_{model}/h=64\)」の行列に横方向に\(i\)番目のヘッド専用の\(W_i^Q,W_i^K,W_i^V\)を用いて低次元化する射影を行う。

\( \begin{aligned} & Q_i'=Q W_i^Q \\ & K_i'=K W_i^K\\
&V_i'=V W_i^V\end{aligned} \)

 

なぜ、dk,dvで分けた?

Q,Kの横ベクトルの次元数を\(d_k\)として、Vの横ベクトルの次元数を\(d_v\)という風に分けて考えている。

これはScaled Dot-Product Attentionの計算上、Q,Kはその内積を取るので、その次元数は同じでなくてはならない。対してその後のSoftmax後のVとの積に関しては次元数が同じでなくてもよい。

なので、\(d_v\)は異なる可能性がありTrandformerにおけるハイパーパラメータの一種であるので分けている

が、ここでは\(h=8\)なので\(d_k=d_v=d_{model}/h\)となる。

 

各ヘッドごとにScaled Dot-Product Attention

Scaled Dot-Product Attentionの大いなる特徴は、その入力の行列のサイズ(Vのサイズ)と出力のサイズが一致しているということである。

つまり、Multi-head Attentionの場合の低次元化した行列Q,K,Vを入力した場合、低次元化した行列が出力される。

低次元化した行列\(Q',K',V'\)が入力として先ほどのScaled Dot-Product Attentionを通ることになる。

$$
\mathbf{Z}_i=\text { Attention }\left(Q_i', K_i', V_i'\right)=\text { Attention }\left(Q W_i^Q, K W_i^K, V W_i^V\right)
$$

これが各ヘッドにおける出力\(\mathbf{Z}_i\)である。

Concat:結合を行って元の次元に戻す

 

 

 

それぞれのヘッドから出力された「シーケンス長\(n \times d_v\)」サイズの行列を横方向にConcat(結合)する。こうすることによって結合した後の行列は「\( n \times d_v \cdot h =n \times d_{model}\)」となる。つまり元のMulti-head Attentionの入力の行列のサイズに戻る。(\(h \times d_v = 8 \times 64 = 512\)次元)

$$
\operatorname{Concat}\left(Z_1, Z_2, Z_3 \cdots, Z_h\right)
$$

全結合層で線形変換

そして結合した行列\(\operatorname{Concat}\left(\mathbf{Z}_1, \mathbf{Z}_2, \ldots, \mathbf{Z}_8\right)\)に対し、「\(d_{model} \times d_{model}\)」のサイズの行列\(W^o\)を右から掛けることによって、最終的な出力を得る。

そしてこの出力は「シーケンス長\(n \times d_{model}\)」サイズの行列になり、元の埋め込み行列と同じサイズに戻る。

$$
\text { MultiHeadAttention }(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{Concat}\left(\mathbf{Z}_1, \mathbf{Z}_2, \ldots, \mathbf{Z}_8\right) \mathbf{W}^O
$$

これは文脈におけるどの意味合いを重視するかを変化させる操作に他ならない

 

これがMulti-head Attentionという機構のすべてである。ここからは行列を分割して行うことの効果について理解を深めよう

分割して処理を行うことの効果

多様な表現を捉えることができる

複数のヘッドに分割する利点は、Q(Query)とK(Key)で内積を計算する際に、様々な関係性を取りこぼさないようにする点にある。

Scaled Dot-Product Attention最大の難点は、内積を取るという性質上、内積の結果が大きくなる成分に大きく依存することである。つまり、ほかの成分同士の関係性が無視される問題が起こる。

実際はこのように分断するわけではない。あくまでイメージ

 

ヘッドに分割するためにQ,K,Vを低次元に射影する線形変換を行い、次元を削減することはWord2vecによって生成された特徴空間における複数の特定の次元に注目し新しい基底ベクトルを生成することになる。

そもそもword2vecの最大の利点は単語をベクトルに変換することで、単語の位置関係を(今回の場合)512次元の線形空間に定義できることで単語同士の演算ができることであった。

「ロンドン」-「イギリス」+「日本」=「東京」
「王」-「男」+「女」=「女王」
「長い」-「熱い」+「熱さ」=「長さ」

例えば、上記のような演算が演算が可能になる。

これは「ロンドン」から「イギリス」を引いたもの。つまり「イギリス」から「ロンドン」に向かうベクトルは「その国の首都を示すベクトル」としてとらえることができる。

このようにword2vecによって生成された線形空間のベクトルは「意味」を持っていると考えることができる。

ということは、その線形空間の基底ベクトルは何らかの意味を持っていて線形変換後の基底ベクトルも何らかの意味を持っているヘッドに分割することはそれらの意味を個別に考慮することになると考えることができる。(と考えられる)

 

本当にそれぞれのヘッドで異なる特徴を捉えられるのかどうかだが、''Attention is all you need''(Vaswani et al, 2017)では次のように異なるヘッドで異なる学習が行われていることを示している。

異なるヘッド(5番目と6番目)で明らかに異なる学習が行われていることがわかる。

出典:Attention is all you need(Vaswani et al, 2017)

しかし、実際には後続の研究で、学習済みのTransformerのMulti-Head Attentionの大部分のヘッドを刈り込んでも(つまり、大部分のヘッド内のAttentionの値を0にしても)大きな性能報告が見られない。という報告がある[Voita19][Michel19]

そのため、Transformerで使われる誤差関数にヘッドの学習の多様性を持たせるような項を加えることでヘッドごとに異なる学習を促すようにすることがある。(ちなみに、Transformerでは一般的に交差エントロピー誤差関数が用いられる)

ニューラルネットワーク1

                  ニューラルネットワークとは ニューラルネットワークは脳の神経細胞(ニューロ ...

続きを見る

例えば以下のような項を加える

ポイント

$$
D_{\text {subpace }}=-\frac{1}{h^2} \sum_{i=1}^h \sum_{j=1}^h \frac{V^i \cdot V^j}{\left\|V^i\right\|\left\|V^j\right\|} \text {. }
$$

各ヘッドのバリュー\(V_i\)でhはヘッド数(\(i = 1,2,\cdots,h\))を示している。

出典:Multi-Head Attention with Disagreement Regularization(Jian Li, Zhaopeng Tu, Baosong Yang, Michael R. Lyu, Tong Zhang,2018)

バリューのベクトル同士のコサイン類似度を計算しており、損失関数に組み込むことでその類似度を最小化するように動くことを意味する。シンプルながら効果がある。

コサイン類似度とは?ベクトルの内積から見る類似度

データ分析や機械学習の分野で広く用いられるコサイン類似度は、ベクトル間の類似性を測定する強力なツールである。   この記事では、コサイン類似度の基本原理から、高次元データでの振る舞い、そして ...

続きを見る

 

アンサンブルの効果が期待できる

Multi-Head Attentionの解釈は先から述べているようにベクトルを分割して「並列で複数行う」ということなので、これは「ランダムフォレスト」のような「アンサンブル学習」を行ったと解釈することができる。これによってロバスト化が期待できるのである。

Multi-Head Attentionでは例えばword2vecの512次元を8分割し、それぞれ64次元ずつをもとにScaled Dot-Product Attentionの処理を行う。

ランダムフォレストでは個々の学習器の学習の際に表型のデータから列と行をランダムに抜き出すことで相関の低い決定木を作成するが、Multi-Head AttentionではWord2Vecのベクトルを分割することでそれぞれ相関の低い計算結果を構築する

 

次元が増えることの悪影響を減らす

内積することの問題点とは次元が大きくなればなるほど内積が0になる確率が高くなるということである。内積の結果の平均が0の分散が小さくなる。

実際に計算すると

コサイン類似度とは?ベクトルの内積から見る類似度

データ分析や機械学習の分野で広く用いられるコサイン類似度は、ベクトル間の類似性を測定する強力なツールである。   この記事では、コサイン類似度の基本原理から、高次元データでの振る舞い、そして ...

続きを見る

となってしまう。次元が100の時でかなり内積の結果が0周辺に偏っている。つまり500次元もあればさらに分散は小さくなり内積の結果が0ばかりになりAttentionの計算ができなくなってしまう。

なので分割して次元を落とすことでアテンションの結果を有意なものにしやすくし、効果が上がるのである。

 

Transformerにおける注意機構の種類

実はこれまで解説してきたのは、自己注意機構というTransformerにおけるもっとも基本的な注意機構で、エンコーダー部分にある。

そしてデコーダー部分には二つの注意機構があり、少しだけ自己注意機構と異なる。何がどう違うのか見てみよう。

 

マスク付き自己注意機構(Masked Multi-Head Attention)とは

Decoderブロックに含まれる二つの注意機構の一つ目「マスク付き自己注意機構」について説明する。

 

Decoder側のブロックの自己注意機構では、マスク処理が行われる。

なぜ、マスク処理が必要になるのかを考えてみよう。そのために、機械翻訳向けのエンコーダー・デコーダー向けモデルを考えてみる

このモデルは、エンコーダーに入力された原言語の\(u_1,u_2,\cdots,u_M\)から目的言語の正解トークン列\(w_1,w_2,\cdots,w_N\)を順に予測しながら訓練される。具体的には、目的言語の正解トークン列\(w_i\)までを予測した状態のとき、訓練は\(u_1,u_2,\dots,u_M\)と\(w_0,w_1,\dots,w_i\)からトークン\(w_{i+1}\)を予測できるようにモデルを更新していくことで行われる。

実際の翻訳タスクやほかの生成推論タスクの時には、このように予測を一つずつ逐次的に行っていく必要がある。つまり、N個のトークンを予測する場合、Transformerに含まれるすべての計算をN回も行う必要がある。しかし、訓練時には効率化高速化のため、正解トークン列\(w_1,w_2,\dots,w_N\)を並列で予測する。そうすることでTransformerの計算を1回で終わらせる。

となると問題が発生してしまう。自己注意機構の魅力である並列処理はトークン列全体から情報を取得するため、例えば\(w_{i+1}\)を予測したいときに\(w_0,w_1,\cdots,w_i\)だけ使いたいところが、答えである\(w_{i+1}\)も入力されるし、さらにそれ以降の\(w_{i+2},\dots,w_{N-1}\)の情報を使ってしまう。

並列化したことによるそのカンニングを防がなければならない。これを解決するために、Decoderはマスク処理を自己注意機構に導入する必要があるのだ。

 

具体的に考えてみよう。「マウスをクリック」という文章が「I click with mouse」という文章に翻訳されるとする。

訓練時には「I click with mouse」という系列が入力される

つまりデコーダーブロックに入ってくる埋め込み行列のサイズは「\( 4 \times d_{model}\)」である。そして出力される行列も「\( 4 \times d_{model}\)」ということになる。

ここで注意しておきたいのは、デコーダーブロックの目的は入力された系列データの次の要素を予測することである。なので同時に入力された場合は次のように

 

というように、その単語の埋め込みベクトルが次の予測に使われるベクトルである。Scaled Dot-Product Attentionを振り返ると、最終的に伝播されていくのはAttention Matrixと掛け合わせたValueベクトルであった。

なので、Attention Matrixを以下のようにする。

 

黒のマスの部分は、自身(Query側のトークン)からすると未来のトークンなので、マスクをかけてアテンションを張れないようにする。

プログラム実装上は、AttentionMapのSoftmax直前直前で該当要素に大きな負の値(例:\(-E+10\))を足す。

例えば一行二列目のマスクが掛かっている部分は「I」というトークンが「click」というトークンにどれだけ影響を受けるかを表している。つまり、一行目の二列目三列目四列目の要素にマスクが掛かっているということは

 

掛け合わせたときに、Clickの予測に使われるベクトルに「click」(正解の情報)や「with」「mouse」などの未来の情報を使われないように学習させることができるのだ。

 

マスク付き自己注意機構はこれだけである。これで、Decoderでの並列処理ができるようになる。

交差注意機構(Crossed Multi-Head Attention)とは

デコーダーの二つ目のAttention機構は「交差注意機構」と呼ばれ、エンコーダーからの出力をデコーダーの入力と組み合わせて使用する。

内部の流れはすべて自己注意機構と同じである。違うのは入力のQ,K,Vがどこから来ているのかである。ここでは、Query(Q)はデコーダーの前のステップの出力から来るが、Key(K)とValue(V)はエンコーダーの出力から来る。この機構により、デコーダーはエンコーダーが処理した情報を活用し、より正確な出力を生成することができる。

エンコーダからの出力をエンコーダーブロックを6回通った埋め込み行列\(X\)という意味で\(X^6\)と書くことし、デコーダーブロックの前からくる行列を\(Z\)と書くことにすると。

$$
Q=Z, K=X^6, V=X^6
$$

これが、交差注意機構の入力になる。ちなみにデコーダーブロックもエンコーダーブロックと同じく6回繰り返すが、何回目の交差注意機構もKとVの入力はすべて同じく\(X^6\)である。

 

最後に復習的に動きを確認しておこう。

 

step
1
線形変換の適用

 

まず、エンコーダーからの出力\(n \times d_{model}\)とデコーダーの入力\(n' \times d_{model}\)に対して、それぞれ別々の線形変換が適用される。これは、Multi-Head Attentionにおける各Headごとの重み行列

$$
\left(W_i^Q, W_i^K, W_i^V\right)
$$

を用いて行われる。この変換により、Query、Key、Valueはそれぞれ\(d_k\)(または\(d_v\))の次元を持つベクトルに変換される。

step
2
Attentionの計算

変換されたQueryとKeyの間で内積を取る。ここで、Queryのサイズは\(n' \times d_k\)、Keyのサイズは\(n \times d_k\)であり、内積の結果として得られるAttention Matrixは\(n' \times n\)になる。この行列は、デコーダーの各位置がエンコーダーの各位置にどれだけ注意を払うべきかを示している。

step
3
スケーリングとSoftmax関数の適用

得られたAttention Matrixに対して、\(\sqrt{d_k}\)で割ってスケーリングし行ごとにSoftmax関数を適用し、正規化する。これにより、各行の要素の和が1になる。

step
4
Valueとの積

次に、Softmax関数を適用したAttention Matrix(\(n' \times n\))と、変換されたValue(\(n \times d_v\))の積を取る。この積は行列の積として計算可能であり、結果として得られる行列のサイズは\(n' \times d_v\)になる。この行列は、デコーダーの各位置からのエンコーダーからの情報の集約を表している。

step
5
結果の結合と最終変換

最後に、Multi-Head Attentionでは、異なるHeadからの出力を結合し、もう一度線形変換を適用することで、最終的な出力を得る。

 

まとめ

この記事ではなるべくわかりやすく。特に行列の形にAttentionを払いながら画像をふんだんに交えて解説を行いました。

単に流れを追うだけではつまらないので、なるべく詳しくなぜその操作を行うのか。現時点で僕にできる限りの考察しました。

何か不備やわかりにくいところなどがあればコメント等で教えていただきたいです。

 

この記事が昨今を席巻するTransformer、ひいてはAIを理解したいと思う人の一助となれば幸いです。

参考先

 

-数学, 時系列分析, 線形代数, 自然言語処理