Python程序到計算圖一鍵轉化,詳解清華開源深度學習編譯器MagPy
張晨,清華大學計算機系高性能所博士生,導師爲翟季冬老師,主要研究方向爲面向人工智能和量子計算的高性能異構計算系統。在OSDI、SC、ATC、ICS會議上發表一作論文,並獲得 ICS21 最佳學生論文。曾獲得 SC19, SC20, ISC21 國際超級計算機競賽冠軍。獲清華大學本科生特等獎學金、國家獎學金、北京市優秀畢業生、北京市優秀畢業設計等榮譽。
2024 年 7 月,清華大學計算機系 PACMAN 實驗室發佈開源深度學習編譯器 MagPy,可一鍵編譯用戶使用 Python 編寫的深度學習程序,實現模型的自動加速。
儘管目前存在大量高性能的深度學習編譯器,但是這些編譯器均以計算圖作爲輸入,需要由用戶將編寫的 Python 程序手動轉化爲計算圖。爲了避免這種不便性,該團隊設計了 MagPy,直接面向用戶編寫的 Python+PyTorch 程序,自動將其轉化爲適用於深度學習編譯器的計算圖表示,從而充分發揮深度學習編譯器的優化能力,避免用戶使用複雜 Python 語法帶來的性能下降,爲用戶帶來易用性和效率的雙豐收。
該工作同時於系統領域重要國際會議 USENIX ATC’24 發表長文,第一作者清華大學博士生張晨、通訊作者爲翟季冬教授。PACMAN 實驗室在機器學習系統領域持續深入研究。MagPy 是繼 PET、EINNET 等工作後在深度學習編譯器上的又一次探索。欲瞭解更多相關成果可查看翟季冬教授首頁:https://pacman.cs.tsinghua.edu.cn/~zjd
研究背景:深度學習計算圖提取技術
近年來,深度學習在生物科學、天氣預報和推薦系統等多個領域展示了其強大能力。爲了簡化編程過程,用戶傾向於使用 Python 編寫深度學習模型,並在需要進行張量操作時調用如 PyTorch 等的張量庫。此時,用戶程序會在調用張量庫時立即執行張量操作,如此不加優化地直接執行程序性能較差。另一方面,爲了提升深度學習模型的運行速度,深度學習編譯器傾向於使用以算子圖的格式表示的深度學習模型作爲輸入,在計算圖上進行圖級優化,如圖替換和算子融合。當可以獲取到模型的計算圖時,代表性的深度學習編譯器 TorchInductor 和 XLA 可以在 PyTorch 的基礎上平均加速模型 1.47 倍和 1.40 倍。
具體結果如圖 1 所示,標記爲 Fullgraph-Inductor 和 Fullgraph-XLA。然而,實現這種加速的前提是用戶手動將程序轉換爲計算圖格式,這對許多模型開發者來說是困難的。尤其是隨着深度學習的廣泛應用,越來越多的模型是由化學、生物和天文學等領域的非專業程序員開發的。因此,迫切需要一種自動化方法將用戶編寫的 Python 程序轉換爲編譯器友好的圖格式來加速程序,這被稱爲計算圖提取技術。
由於 Python 程序具有極強的動態性,加之用戶程序存在行爲的不確定性,現有的計算圖提取技術在處理較複雜的用戶程序時無法取得最優的性能,如圖 1 中的 TorchDynamo-Inductor(使用 TorchDynamo 提取計算圖,使用 TorchInductor 編譯)、 LazyTensor-XLA(使用 LazyTensor 追蹤計算圖,使用 XLA 編譯)所示。
圖 1 :深度學習編譯器可以顯著提升模型運行效率,但現有的圖提取技術阻礙了這一點。圖中 Eager 表示直接執行 PyTorch 程序,Fullgraph-Inductor 與 Fullgraph-XLA 分別表示 Inductor、XLA 對模型的計算圖進行編譯後的加速,TorchDynamo-Inductor 與 LazyTensor-XLA 分別表示使用 TorchDynamo 和 LazyTensor 技術從用戶 Python 程序中提取計算圖再進行編譯的性能。
MagPy 的解決方案
MagPy 的核心思想是分析 Python 解釋器中的執行狀態信息,從而讓編譯器能夠更好的理解用戶程序。Python 解釋器能夠準確支持所有 Python 特性,並在運行時保留了高層次的執行狀態信息,如各個變量的類型和值等等。通過有效利用解釋器提供的信息,能夠更全面地瞭解程序的行爲,從而更好地提取程序計算圖。
MagPy 的設計基於以下幾點觀察:
首先,大多數深度學習程序的動態性是有限的。儘管這些程序是用 Python 編寫的,具有數據類型、控制流邏輯和運行時函數調度等潛在的動態特性,但其計算圖結構在不同批次間通常保持不變。ParityBench 是一個從 Github 上自動爬取超過 100 顆星的 PyTorch 深度學習程序組成的基準測試集,它的 1421 個程序中,83% 的程序(1191 個)均滿足有限動態性的性質。對於這些程序,可以通過在程序執行過程中監控張量操作,較爲簡便地獲取其計算圖。根據這個性質,MagPy 將計算圖提取問題從分析 “計算圖是什麼” 簡化爲分析 “得到的計算圖何時會發生變化”。
其次,只有外部值能影響程序行爲。利用這一特性,可以更簡易地檢測出會導致計算圖發生變化的因素。這裡的 “程序行爲” 包括計算圖的結構和所有程序副作用(side effect)。只要程序從外部讀取的所有值(如輸入參數和全局變量)保持不變,且調用的函數的輸出結果不具有隨機性,程序行爲就不會發生變化。因此,MagPy 只需驗證所有從外部讀取的值都不變,即可保證計算圖結構不變。例如,儘管圖 2 中的程序使用了許多複雜的 Python 特性,但只要所有從外部讀取的值(如 x、dims、self.scale 和 self.dim,標記爲粗體)與之前運行一致,計算圖就保持不變。MagPy 會首先運行一個 “守衛函數” 對於這些值是否發生變化進行檢查(Guards),當檢查通過時,MagPy 將會運行一個 “模擬函數”(mock code),用以調用經過深度學習編譯器編譯的計算圖及模擬程序的所有副作用(如示例中的對 x 進行賦值)。
第三,守衛函數和模擬函數都可以通過分析程序執行狀態來確定。守衛函數的作用是驗證新一次執行的輸入狀態是否與之前運行匹配,模擬函數的目的是重現之前運行的最終狀態。這兩個部分僅基於運行時狀態,而不是用戶程序的邏輯。Python 解釋器在解釋執行程序的過程中,保留了所有需要的執行狀態信息,因此不再需要具體分析 Python 複雜而動態的執行邏輯。守衛函數和模擬函數需要關注的變量包括顯式讀取或寫入外部的值(如 self)以及被它們引用的值(如 self.dim)。因此,MagPy 設計了引用關係圖來記錄和分析程序行爲。
基於上述觀察,MagPy 提出了引用關係圖(Reference Graph,簡寫爲 RefGraph)來記錄程序執行期間的程序狀態。MagPy 定義了執行狀態接口,用於在程序執行期間收集運行時信息,並使用基於標註的圖更新規則來維護 RefGraph。MagPy 還提出了在 RefGraph 上進行遍歷生成守衛函數和模擬函數的算法。具體細節可以閱讀論文。
實驗
MagPy 具有極高的 Python 語言特性覆蓋率,其在對 ParityBench 中 1191 個靜態的真實用戶程序進行測試時,成功將 93.40% 的程序轉化爲完整的操作符圖,大幅高於現有工作 TorchScript(35%)和 TorchDynamo(77.2%)
由於更完整的計算圖導出,MagPy 在端到端測試中,也表現出具有競爭力的性能。下圖展示了對於圖像處理、自然語言處理等典型深度學習模型,MagPy 取得的加速。MagPy 可取得最高 2.88 倍,平均 1.55 倍的提升。實驗在單張 A100 上進行,X-Y 表示使用圖導出技術 X 和圖層編譯器 Y。