Faster Transformer 是NVIDIA 針對Transformer推理提出的性能最佳化方案,這是一個BERT Transformer 單層前向計算的高效實現,代碼簡潔明了,後續可以通過簡單修改支持多種Transformer結構。
Faster Transformer已經開源。
基本介紹
- 外文名:Faster Transformer
- 開發商:NVIDIA
- 源碼模式:開源
- 軟體類型:性能最佳化方案
產品背景,產品功能,使用方法,
產品背景
BERT是2018年10月由Google推出的一個大型計算密集型模型,其為自然語言理解奠定了基礎。通過微調,它可以套用於廣泛的語言任務,如閱讀理解、旬判情感分析或問答。
而Transformer是一種通用高效的己嚷霉特徵抽取器,2017年12月在Google的論文“Attention is All You Need”中被首次提出,將其作為一種通用高效的特徵抽取器。
Transformer已被多種NLP模型採用,比如BERT以及XLNet,這些模型在多項NLP任務中都有突出表現。
在NLP之外, TTS,ASR等領域也在逐步採用Transformer。可以預見,Transformer這個簡潔有效的網路結構會像CNN和RNN一樣被廣泛採用。
然而,雖然Transformer在多種場景下都有優秀的表現,但是在推理部署階段,其計算性能卻受到了巨大的挑戰:以BERT為原型的多層Transformer模型,其性能常常難以滿足線上業務對於低延遲(保證服務質量)和高吞吐(考慮成本)的要求。以BERT-BASE為例,超過90%的計算時間消耗在12層Transformer的前向計算上。
因此,一個高效的Transformer 前向計算方案,既可以為線上業務帶來降本增效的作用,也有利於以Transformer結構為核心的各類網路在更多實際工業場景中落地。
於是,NVIDIA隊針對Transformer推理提出了性能最佳化方案——FasterTransformer。Faster Transformer是一個開源的高效Transformer實現,相比TensorFlow XLA 可以帶來1.5-2x的提速。
產品功能
Faster Transformer底層由CUDA和cuBLAS實現,支持FP16和FP32兩種計算模式,其中FP16可以充分利用Volta和Turing架構GPU上的Tensor Core計算單肯察院元。
Faster Transformer共接收4個輸入參數。首先是attention head的數量以及每個head的維度。這兩個參數是決定Transformer網路結構的關鍵參厚譽櫻數。這兩個參數的動態傳入,可以保證Faster Transformer既支持標準的BERT-BASE(12 head x 64維),也支持裁剪過的模型(例如,4 head x 32維),或者其他各式專門定製化的模型。其餘兩個參數是Batch Size 和句子最大長度。出於性能考慮,句子最大長度固定為最常用的32,64 和128三種,未來會支持任意長度。恥促遷
Faster Transformer對外提供C++ API, TensorFlow OP,以及TensorRT Plugin三種接口。
使用方法
1. 在TensorFlow中使用Faster Transformer
在TensorFlow中使用Faster Transformer最為簡單。只需要先import .so檔案,然後在代碼段中添加對Faster Transformer OP的調用即可。
2. 使用C++ API或者TensorRT 調用Faster Transformer
考慮到封裝成TensorFlow OP會引入一些額外的開銷,建議用戶直接使用C++ API或者TensorRT Plugin的方式進行集槳斷拔婚成。這兩種方式不支持直接解析訓練好的模型。Transformer層所需要的weights參數,需要用戶手動從訓練好的腳盼擊譽模型中導出。調用方式相對簡單,將導出的weights賦值給參數結構體,創建相應的對象,調用initialize或者build_engine函式初始化對象。運行時,每次調用forward或者do_inference即可。
在TensorFlow中使用Faster Transformer最為簡單。只需要先import .so檔案,然後在代碼段中添加對Faster Transformer OP的調用即可。
2. 使用C++ API或者TensorRT 調用Faster Transformer
考慮到封裝成TensorFlow OP會引入一些額外的開銷,建議用戶直接使用C++ API或者TensorRT Plugin的方式進行集成。這兩種方式不支持直接解析訓練好的模型。Transformer層所需要的weights參數,需要用戶手動從訓練好的模型中導出。調用方式相對簡單,將導出的weights賦值給參數結構體,創建相應的對象,調用initialize或者build_engine函式初始化對象。運行時,每次調用forward或者do_inference即可。