大多數的深度學習模型使用的是32位單精度浮點數(FP32)來進行訓練,而混合精度訓練的方法則通過16位浮點數(FP16)進行深度學習模型訓練,從而減少了訓練深度學習模型所需的記憶體,同時由於FP16的運算比FP32運算更快,從而也進一步提高了硬體效率。
基本介紹
- 中文名:混合精度訓練
- 套用領域:深度學習
- 使用技術:32位單精度浮點數
概述,原理介紹,關鍵技術,
概述
深度學習模型的計算任務分為訓練和推理.訓練往往是放在雲端或者超算集群中,利用GPU強大的浮點計算能力,來完成網路模型參數的學習過程.一般來說訓練時,計算資源往往非常充足,基本上受限於顯存資源/多節點擴展/通訊庫效率的問題。相對於訓練過程,推理往往被套用於終端設備,如手機,計算資源/功耗都收到嚴格的限制,為了解決這樣的問題,提出了很多不同的方法來減少模型的大小以及所需的計算資源/存儲資源。模型壓縮除了剪枝以外,還有一個方法就是降低模型參數的數值精度。隨著網路深度的加大,帶來的參數數量也呈現指數級增長,如何將最終學習好的網路模型塞入到終端設備有限的空間中是目前很多性能優良的網路真正套用到日常生活中的一大阻礙。
原理介紹
通過用半精度運算替代全精度運算來提高效率,這一技術原理聽起來很簡單明了,但將其付諸實施並不像聽起來那么簡單。此前也有團隊嘗試過使用更低精度進行混合計算(如二進制,甚至4-bit),但問題在於這往往不可避免地造成結果的準確性和在主要網路變換上的損失,而百度的MPT模型不僅解決了這一問題,更重要的是MPT無需改變網路超參數,並保持與單精度相同的準確性。
在百度研究院部落格中,百度進一步解釋了這一模型的原理:
深度學習模型由各種層(Layer)組成,包括完全連線的層,卷積層和反覆層。層與層之間的轉換可以通過通用矩陣乘法(GEMM)來實現,而對深度學習訓練的過程其實很大程度是GEMM計算的過程。
當使用FP16代表神經網路中的數據時,GEMM操作的輸入矩陣由16位數組成。我們需要可以使用16位計算執行乘法的硬體,但是需要使用32位計算和存儲來執行加法。使用少於32位的加法操作訓練大型深度學習模型會非常困難。
為此,百度不僅與NVIDIA共同解決了硬體支持的問題,雙方還對訓練流程進行了一些修改,模型中的輸入,權重,梯度和激活以FP16格式表示。但是如之前介紹,與FP32數字相比,半精度數字的範圍有限,只是通過簡單地更改存儲格式,某些模型無法達到與單精度相同的精度。
關鍵技術
第一項關鍵技術被稱為“混合精密鑰匙”(mixed precision key)。如下圖所示,在MT模型中仍然保留FP32格式的主副本,將FP16用於正向和反向傳播,最佳化器中的梯度更新將被添加到主FP32副本當中,該FP32副本被簡化為一個FP16副本在訓練期間使用,這個過程在每次訓練疊代中重複,直至模型收斂且足以恢復損失的精度,從而達到較低記憶體使用、記憶體頻寬壓力更低和更快速執行的優點。
第二種關鍵技術則是“損耗縮放”(loss-scaling)。該技術可以夠恢復一些小數值的梯度。在訓練期間,一些權重梯度具有非常小的指數,其FP16格式可能會變為零。為了克服這個問題,我們使用縮放因子在反向傳播開始時縮放損失,通過連鎖規則,梯度也逐漸擴大,並在FP16中可表示。在將其更新套用於權重之前,梯度確實需要縮小;而為了恢復某些型號的精度損失,必須進行損耗調整。關於這兩種技術的更多細節可以在我們的論文中找到。
百度已使用這種方法使用FP16訓練其DeepSpeech 2模型。結果表明,對於英文和國語模型和數據集和使用相同的超參數、模型架構進行混合精度訓練實驗,可以得到到FP32訓練的精度。
同時,使用FP16訓練減少了深度學習模型的記憶體需求,使得百度能夠使用一半的處理器來訓練這些模型,從而有效地加倍了集群大小。此外,FP16算術的峰值性能(如上所述)通常高於單精度計算。