低温の平均場トランスフォーマーでトークン分布が急速に集中することを定量化
この論文は、トランスフォーマーの推論時にトークン(入力の単位)がどう動くかを、大量のトークンを扱う極限で調べた研究です。著者らは、その動きを「平均場の連続方程式」と呼ばれる確率分布の方程式で記述し、低温と呼ばれる設定で分布の集中が起きることを定量的に示しました。低温とは数学的には温度パラメータ β^{-1} が 0 に近づく、つまり β が大きい状況を指します。
研究の中心的な結論はこうです。初期のトークン分布は、キー行列・クエリ行列・バリュー行列(トランスフォーマー内部で使われる行列)から定まるある射影(投影)写像の下で押し出された(push-forward)分布に急速に近づき、その近い状態が中程度の時間では準安定(metastable)に保たれる、というものです。「押し出す(push-forward)」とは簡単に言えば、元のトークンをその写像で変換して得られる新しい分布のことです。分布の近さはワッサースタイン距離という確率分布間の距離で測り、著者らはその距離が次のように評価されると示しました: sqrt(log(β+1)/β)·exp(Ct) + exp(-c t) 。ここで t は推論時間、C と c は正の定数です。式から、β が大きい(低温)とき、時間スケールがログ(β) 程度までは分布が濃縮することが分かります。
証明にはいくつかの解析的手法が使われています。まず零温度(β→∞)の場合の方程式に対してライアプノフ(Lyapunov)型の見積もりを得て、時間が大きくなるとどんな分布に落ち着くかを特定しました。次にワッサースタイン空間での安定性の見積もりと、確率論で使われる定量的ラプラス原理(積分が極大点の周りに集中する性質の定量化)を組み合わせて、零温度方程式と有限だが大きいβの方程式を結び付けています。専門用語を平たく言えば、零温度での振る舞いを理解し、その近さが有限温度でも保たれることを丁寧に示した、ということです。
数値実験も行われています。これらは理論結果を裏付けるものです。同時に、理論が扱わない振る舞いも見つかりました。有限のβ(実用的には無限大でない温度)で長時間を考えると、ダイナミクスは別の終局相(terminal phase)に入り、その振る舞いはバリュー行列の固有値のスペクトル(行列の基本的な周波数特性)に支配されることが示されました。つまり、低温かつ中程度の時間では著者らの主張が当てはまるが、十分に長い時間や温度が高い場合は別の現象が出る可能性があります。
重要な注意点です。この結果は「大トークン数の極限(平均場近似)」と「低温領域(β 大)」という仮定の下での理論的な解析です。対象は深いエンコーダーのみのトランスフォーマーの推論時の振る舞いであり、学習(トレーニング)の過程やすべての実際のモデル設定にそのまま当てはまるとは限りません。理論は特定の時間スケールやパラメータ領域での現象を精密に説明しますが、有限トークン数や高温、長時間挙動については別の分析や追加の数値検証が必要です。