リモート開発メインのソフトウェア開発企業のエンジニアブログです

BERTのモデル構造をもう少し詳しく

以前にBERTに関する記事を投稿しました

第二回で、BERTのモデル構造について簡単に解説を書きました。今回はさらに細かく、レイヤーの中身まで見ていきましょう。

前回の復習

  • BERTの基本部分は、Transformerのエンコーダをスタックしているだけ

BERTの論文には、以下のように述べられています。


要約すると

とあります。

Transformerの中には、Attentionというレイヤーがありますが、Attentionについての解説はググると色々と解説はでてきます。
なので、Attentionのレイヤーを見つけるまであたりの解説、その際の注意点に触れていこうと思います。

Moba Pro

Transformer

というわけで、軽くTransformerの構造について、確認をしておくと良いでしょう。

これの左側のブロックが、エンコーダの部分に相当します。BERTで使うのはエンコーダの方なので、今回はこの部分だけ軽く見ていきます。

まずこの画像では、ブロック内の説明が一部簡略化されています。

  • 各サブレイヤーの出力にdropoutがあります。
  • Add & Normは、足してからNormalization Layerを置くということも注意

このことは論文にも書いてありますし、メジャーな実装を見て確認してみるのも良いと思います。PyTorchにもTransformerのモデルが入っていますので、上記の点を確認して見てください。

図にするとこのような感じになります。

エンコーダ内は2つのブロックがあり、個別のブロックは

  • サブレイヤー
  • Dropout
  • 残差接続
  • Layer Normalization

で構成されています。2ブロックの各々のサブレイヤーは、Feed Forward と Multi-Head Attentionです。

Feed Forward

Feed Forwardは以下の3つで構成されます。

  • dense
  • activation function(活性化関数)
  • dense

活性化関数は、Transformerの論文ではReLUですが、BERTではGeLUが使われています。

Multi-Head Attention

残りの理解すべきレイヤーは、Multi-Head Attentionですね。
Attentionについては冒頭でも述べたとおり、調べると様々な解説がありますのでここでは飛ばします。

日本語で読める参考資料

個人的には、まず計算式のみを追うことをおすすめします。
”辞書”のような役割を果たす、というのは知っておいたほうが良いかとは思いますが、そもそもMulti-head化された上でレイヤーを重ねて複数存在する個別のAttentionのパラメータは、それを見ても人間が理解できるような辞書のような形をしてないと思われます。

BERTの構造を見る

Transformerについては以上までに留めて、BERTのモデルを確認していきます。

メジャーに使われている実際に動くモデルを使って構造を確認したいですね。論文の実装はTensorFlowで、個人的には読みづらいです。
そこで、PyTorch実装であるhuggingface/transformersを見るのがおすすめです。

Transformer系の有名なモデル、それらの事前学習データを簡単にダウンロードできる仕組みも入った使いやすいライブラリです。

さらに、PyTorchのモデルの構造の中身を確認できるtorchinfoも使いましょう

from transformers import BertForPretraining
from torchinfo import summary

# 今回は構造を見るだけなのでpretrainedデータをダウンロードする必要はありませんが、めんどいのでこうやります
model = BertForPretraining.from_pretrained('bert-base-uncased')
summary(model,depth=4)
===========================================================================
Layer (type:depth-idx)                             Param #
===========================================================================
├─BertModel: 1-1                                   --
|    └─BertEmbeddings: 2-1                         --
|    |    └─Embedding: 3-1                         23,440,896
|    |    └─Embedding: 3-2                         393,216
|    |    └─Embedding: 3-3                         1,536
|    |    └─LayerNorm: 3-4                         1,536
|    |    └─Dropout: 3-5                           --
|    └─BertEncoder: 2-2                            --
|    |    └─ModuleList: 3-6                        --
|    |    |    └─BertLayer: 4-1                    7,087,872
|    |    |    └─BertLayer: 4-2                    7,087,872
|    |    |    └─BertLayer: 4-3                    7,087,872
|    |    |    └─BertLayer: 4-4                    7,087,872
|    |    |    └─BertLayer: 4-5                    7,087,872
|    |    |    └─BertLayer: 4-6                    7,087,872
|    |    |    └─BertLayer: 4-7                    7,087,872
|    |    |    └─BertLayer: 4-8                    7,087,872
|    |    |    └─BertLayer: 4-9                    7,087,872
|    |    |    └─BertLayer: 4-10                   7,087,872
|    |    |    └─BertLayer: 4-11                   7,087,872
|    |    |    └─BertLayer: 4-12                   7,087,872
|    └─BertPooler: 2-3                             --
|    |    └─Linear: 3-7                            590,592
|    |    └─Tanh: 3-8                              --
├─BertPreTrainingHeads: 1-2                        --
|    └─BertLMPredictionHead: 2-4                   --
|    |    └─BertPredictionHeadTransform: 3-9       --
|    |    |    └─Linear: 4-13                      590,592
|    |    |    └─LayerNorm: 4-14                   1,536
|    |    └─Linear: 3-10                           23,471,418
|    └─Linear: 2-5                                 1,538
===========================================================================
Total params: 133,547,324
Trainable params: 133,547,324
Non-trainable params: 0
===========================================================================

以上のように、モデル内部のレイヤーを階層で表示してくれます。
このなかのBertEncoderが、エンコーダの部分に相当することがわかります。中にBertLayerが重なっているのがわかります。

今回の目的はBertLayerの中身なので、これを詳しく見ていきましょう。

from transformers import BertConfig, BertLayer
config = BertConfig()
summary(BertLayer(config))
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
├─BertAttention: 1-1                     --
|    └─BertSelfAttention: 2-1            --
|    |    └─Linear: 3-1                  590,592
|    |    └─Linear: 3-2                  590,592
|    |    └─Linear: 3-3                  590,592
|    |    └─Dropout: 3-4                 --
|    └─BertSelfOutput: 2-2               --
|    |    └─Linear: 3-5                  590,592
|    |    └─LayerNorm: 3-6               1,536
|    |    └─Dropout: 3-7                 --
├─BertIntermediate: 1-2                  --
|    └─Linear: 2-3                       2,362,368
├─BertOutput: 1-3                        --
|    └─Linear: 2-4                       2,360,064
|    └─LayerNorm: 2-5                    1,536
|    └─Dropout: 2-6                      --
=================================================================
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
=================================================================

BertLayerの中身は以下のように分かれていることがわかります。

  • BertAttention
  • BertIntermediate
  • BertOutput

これは、前節での話と比べると少し違和感がありますね。
結論を言えば、これは上に説明したTransformerのエンコーダ部分と同じなのですが、huggingface/transformersでは、レイヤーの分割の仕方が異なります。
これを図にすると以下のようになります。

Image from Gyazo

Attentionの方は素直に対応しています。ポイントはIntermediateとOutputの区分です。サブレイヤー内の最後のdenseだけが後ろ側のBertOutputになっています。

なぜこのようになっているのかは、ソースを読んでも分かりませんでした。denseとactivation functionの2つで1つのまとまりとみなしたのかなと思います。

まとめ

BERTに関する解説を探すと、個別の概念についての簡単な紹介はあるものの、全体像が見えずに煙に巻かれたような気持ちになっていました。
実際に動くコードを使って確認をすることでこれが解消できたと思っています。
BERTの中身がよくわからない人の助けになれば幸いです。(とはいえAttentionの説明は投げましたが)


補足:PreNormとPostNorm

補足として、Transformerの論文の解説記事であるThe Annotated Transformerのコードは、論文とは異なった実装になっているという点で初学者にとって罠なので気をつけましょう!!

Layer Normalizationの位置が、残差接続ではなく、サブレイヤー前になっています。

Layer Normalizationをサブレイヤーの前後どちらに置くかで、PreNorm, PostNormと呼ばれます。それに関する論文もあるようです。

今回見たTransformerやBERTはPostNormであり、The Annotated TransformerはPreNormということになります。

BERTの論文内で触れられているtensor2tensorは、当初はPostNormだったようですが、後にこの位置関係を簡単に設定で切り替えられるようにした上で、デフォルトだとPreNormになっていました。

← 前の投稿

SparkでDataFrameの内容を単一のファイルに保存する

次の投稿 →

Ansibleに関するエラーを解消する

コメントを残す