再戰Transformer!原作者帶隊的Mamba 2來了,新架構訓練效率提升

機器之心報道

機器之心編輯部

自 2017 年被提出以來,Transformer 已經成爲 AI 大模型的主流架構,一直穩居語言建模方面 C 位。

但隨着模型規模的擴展和需要處理的序列不斷變長,Transformer 的侷限性也逐漸凸顯。一個很明顯的缺陷是:Transformer 模型中自注意力機制的計算量會隨着上下文長度的增加呈平方級增長。

幾個月前,Mamba 的出現打破了這一局面,它可以隨上下文長度的增加實現線性擴展。隨着 Mamba 的發佈,這些狀態空間模型 (SSM) 在中小型規模上已經實現了與 Transformers 匹敵,甚至超越 Transformers。

Mamba 的作者只有兩位,一位是卡內基梅隆大學機器學習系助理教授 Albert Gu,另一位是 Together.AI 首席科學家、普林斯頓大學計算機科學助理教授 Tri Dao。

Mamba 面世之後的這段時間裡,社區反應熱烈。可惜的是,Mamba 的論文卻慘遭 ICLR 拒稿,讓一衆研究者頗感意外。

僅僅六個月後,原作者帶隊,更強大的 Mamba 2 正式發佈了。

總體而言,本文提出了 SSD(state space duality)框架,基於此,研究者設計了一個新的體系架構 Mamba-2,其核心層是對 Mamba 的選擇性 SSM 的改進,速度提高了 2-8 倍,同時在語言建模方面繼續與 Transformers 競爭。

Tri Dao 表示,他們構建了一個豐富的 SSD 理論框架,許多線性注意力變體和 SSM 是等效的,由此產生的模型 Mamba-2 比 Mamba-1 更好、更快。

Mamba-2 的新算法使其能夠利用更大的狀態維度 (16 → 256),同時訓練速度更快。在需要更大狀態容量的任務上,例如 MQAR 任務,它比 Mamba-1 有了顯著的改進。

此外研究者還發現,最近新出的混合模型(Jamba、Zamba)增加了一些注意力層來提高模型質量。基於這些發現,研究者將 4-6 個注意力層與 Mamba-2 層混合,其表現優於 Transformer++ 和純 Mamba-2,因而得出注意力和 SSM 是互補的。

這項研究的貢獻概括爲:

本文展示了狀態空間模型與一類稱爲半可分矩陣的結構化矩陣族之間的等價性。這一聯繫是 Mamba-2 框架的核心,揭示了狀態空間模型的新屬性和算法。

本文顯著改進了線性注意力理論,首先通過張量收縮的語言對其循環形式提供了一個明確的證明,然後將其推廣到一種新的結構化掩碼注意力(SMA)家族。

本文將 SSM(狀態空間模型)和 SMA(結構化掩碼注意力)聯繫起來,顯示它們有一個很大的交集,彼此是對偶的,同時具有 SSM 式的線性形式和類似注意力的二次方形式。本文還證明了任何具有快速循環形式的核注意方法都是 SSM。

除了內在的理論價值外,研究者所提出的框架爲理解和改進序列模型開闢了廣闊的方向。

在算法層面。所提框架爲計算 SSM 提供了新的高效且易於實現的算法。本文提出了一種基於半可分離矩陣塊分解的 SSD 算法,該算法利用了 SSM 線性遞推和二次對偶形式,在所有主要效率軸上獲得了最優的權衡。基於 SSD 的實現比 Mamba 的優化選擇性掃描實現快 2 到 8 倍,同時允許使用更大的循環狀態大小(是 Mamba 的 8 倍甚至更高,且幾乎不影響速度)。SSD 與優化過的 softmax 注意力實現(FlashAttention-2)具有高度競爭力,在序列長度 2k 時性能相當,在序列長度 16K 時速度快 6 倍。

架構設計。採用 SSM 等新架構的一個主要障礙是針對 Transformers 量身定製的生態系統,例如用於大規模訓練的硬件高效優化和並行技術。本文框架允許使用已建立的慣例和技術來構建 SSM 的架構設計選擇詞彙表,並進一步改進它們。

本文還對 Mamba 塊做了一些修改,這些修改允許實現張量並行,其主要思想包括引入分組值注意力 (GVA,grouped-value attention) 頭結構。

將修改後的並行 Mamba 塊與作爲內部 SSM 層的 SSD 結合使用,形成了 Mamba-2 架構。研究者在與 Mamba 相同的設置中研究了 Mamba-2 的 Chinchilla 擴展法則,發現它在困惑度和實際運行時間方面均優於 Mamba 和 Transformer++。研究者還在 Pile 數據集上訓練了一系列 Mamba-2 模型,結果顯示 Mamba-2 在標準下游評估中匹配或超過 Mamba 和開源的 Transformers。例如,在 Pile 上訓練了 3000 億 token 的 2.7B 參數的 Mamba-2 在性能上超過了在同一數據集上訓練的 2.8B 參數的 Mamba 和 Pythia 以及 6.9B 參數的 Pythia。

系統優化:SSD 框架連接 SSM 和 transformer,允許利用爲 transformer 開發的豐富的系統優化工作。

SSD 層

Mamba-2 的核心貢獻是新的 SSD(state space dual)層。SSD 層可以被定義爲選擇性 SSM 的特例。與 Mamba 相比,Mamba-2 的改動會略微降低表達能力,但卻顯著提高了訓練效率,特別是允許在現代加速器上使用矩陣乘法單元。

SSD 層的對偶注意力:

除了最新的 SSD 層,研究者也對 Mamba 的神經網絡架構做了一些小的改變,Mamba-2 架構如下所示。

Mamba-2 在網絡架構上的主要變化是從順序生成變爲並行生成 SSM 參數,並且 Mamba-2 更適合張量並行等擴展方法。

通過提供狀態空間模型的顯式矩陣變換形式,研究團隊揭示了理解和使用它們的新方法。從計算的角度來看,任何計算狀態空間模型前向傳播的方法都可以看作是半可分離矩陣上的矩陣乘法算法。半可分離矩陣視角爲 SSD 提供了一個視角,其中雙重模式分別指的是線性時間半可分離矩陣乘法算法和二次時間樸素矩陣乘法。

研究團隊定義了結構化狀態空間模型和結構化注意力,討論了它們的屬性,並表明它們都有二次算法和線性算法。

自最初的 Mamba 論文研究了合成任務 —— 如:合成複製和歸納 Head 以來,許多後續工作開始研究更難的關聯回憶任務。由 Zoology 和 Based 系列工作引入的 MQAR(multi-query associative recall)任務已成爲事實上的標準。

通過運行一個比文獻中通常報告的版本要難得多的任務,該團隊發現 Mamba-2 明顯優於 Mamba-1,而改善性能的一個原因是狀態大小(比 Mamba-1 大約 16 倍)。

在這篇文章中,作者深入探討了模型背後的理論。

從兩個完全不同的角度推導出 SSD 的「對偶性」:

SSD 框架提供了狀態空間模型、注意力機制和結構化矩陣之間豐富的聯繫。

雖然 SSD 模型可以被視爲框架內每個分支的具體實例,但 SSD 框架本身更加通用,爲未來的工作開闢了許多方向。

SSD 框架(紅色,藍色):狀態空間模型(即半可分矩陣)和結構化掩碼注意力機制包含了大量高效的序列模型。它們的交集是 SSD 模型(紫色)。

SSD 算法

通常,矩陣乘法(matmul)的 FLOPs 速度要比非矩陣乘法 FLOPs 快得多(高達 16 倍):A100 GPU 具有 312 TFLOPS 的 BF16 矩陣乘法性能,但只有 19 TFLOPS 的 FP32 算術性能,而 H100 具有 989 TFLOPS 的 BF16 矩陣乘法性能,但只有 67 TFLOPS 的 FP32 算術性能。

Mamba-2 的主要目標之一是「利用張量核心加速 SSM」。

在綁定參數並引入 Head 結構後,Mamba-1 中的 SSM 變成了 SSD,這是一種更具限制性的形式,具有類似注意力的公式。並且由於 SSD 連接 SSM 和結構化矩陣,計算 SSM 的高效算法直接對應於「token-mixing」或「sequence-mixing」矩陣 M 的不同分解。

因此,可以通過尋找替代的矩陣乘法方式,例如通過各種方式對其進行分解,從而創建計算 SSM 的新算法。

通過精心選擇塊大小,對這個矩陣進行簡單塊分解,就可以集 SSD 線性遞歸和二次注意力對偶形式的兩種優勢於一身。

而這也就是 SSD 算法的起源,它有 4 個步驟,並且對於這個算法有兩種完全不同的詮釋。

SSD 算法:分塊矩陣分解

首先將半可分 SSM 矩陣劃分爲大小爲 Q×Q 的塊,然後,利用半分矩陣的性質來分解每個低秩的非對角塊:

SSD 算法:分塊和狀態傳遞

該算法的另一種詮釋涉及「推理 SSM 如何在實際序列上進行操作」。

首先將輸入序列分割成大小爲 Q 的塊,步驟可以分爲:

可以看到,大部分算法(步驟 1、2 和 4)利用了矩陣乘法(因此利用了張量核心),而且可以並行計算。

只有步驟 3 需要掃描,但它只操作一個非常短的序列,通常只需要很少時間。

系統及擴展優化

張量並行

使用張量並行對 Mamba-1 進行大規模訓練的一項困難是,每層都需要 2 次 all-reduce,而在 Transformer 中,每個注意力或 MLP 層只需 1 次 all-reduce。這是因爲 SSM 的一些參數是內部激活的函數,而不是層的輸入函數。在 Mamba-2 中,由於採用了「並行投影」結構,所有 SSM 參數都是層輸入的函數,因此可以輕鬆地將張量並行應用於輸入投影:將輸入投影和輸出投影矩陣分割成 2、4、8 個碎片,具體取決於張量並行度。研究者使用 grouped norm,分組數除以張量並行度,這樣每個 GPU 都能單獨完成歸一化。這些變化導致每層只需 1 次 all-reduce,而不是 2 次。

序列並行

在對超長序列進行訓練時,可能需要沿着序列長度進行分割,並將不同部分分配給不同的設備。序列並行主要有兩種形式:對於殘差和歸一化操作,用 reduce-scatter、殘差 + 歸一化、然後 all-gather,取代張量並行中的 all-reduce。由於 Mamba-2 使用與 Transformer 相同的殘差和歸一化結構,因此這種形式的序列並行無需修改即可直接應用。對於注意力或 SSM 操作,又稱上下文並行(CP)。對於注意力,可以使用環形注意力沿序列維度進行分割。對於 Mamba-2,SSD 框架再次提供了幫助:使用相同的蒯分解,可以讓每個 GPU 計算其本地輸出和最終狀態,然後在更新每個 GPU 的最終輸出之前,在 GPU 之間傳遞狀態(使用發送 / 接收通信原語)。

實驗結果

該研究在 MQAR 的一種具有挑戰性的版本上,使用更難的任務、更長的序列和更小的模型進行了對比實驗。基線包括標準的多頭 softmax 注意力以及 Based 架構,實驗結果如圖 8 所示。

下表顯示了 Mamba-2 在一系列下游零樣本評估任務上的性能:

感興趣的讀者可以閱讀論文原文,瞭解更多研究內容