日本語BERTモデルをPyTorch用に変換してfine-tuningする with torchtext & pytorch-lightning
TL;DR
①TensorFlow版訓練済みモデルをPyTorch用に変換した
(→方法だけ読みたい方はこちら)
②①をスムーズに使うための torchtext.data.Dataset
を設計した
③PyTorch-Lightningを使ってコードを短くした
はじめに
日本語Wikipediaで事前学習されたBERTモデルとしては, 以下の2つが有名であり, 広く普及しています:
- SentencePieceベースのモデル (Yohei Kikuta さん提供)
- TensorFlow版
- Juman++ベースのモデル (京大黒橋研提供)
- TensorFlow版
- PyTorch版(Hugging Face transformers準拠)
このうち, SentencePieceベースのものは現在TensorFlow版のみの提供となっており, PyTorch版は存在しません。
そのため, 私のようなPyTorchユーザーでがっくり肩を落とされた方は多いのではないでしょうか?
しかし決して諦めることはありません。
実は, ほんの少し工夫するだけでPyTorch版に変換することは可能です!
早速試していきましょう。
(本記事の手法を試すにあたり, kaggler-ja slackの皆さんには多くの助言をいただきました。この場を借りてお礼申し上げます)
環境
- Google Colaboratory
- Python 3.6.9
- TensorFlow 1.15.0
- PyTorch 1.3.1
- Torchtext 0.3.1
- PyTorch-Lightning 0.5.3.2
実践
0. 下準備
0-1. Yohei Kikutaさん版日本語BERTモデルの取得
こちらに公開されているファイルを取得しておきます。
- BERTモデルのCheckpoint
model.ckpt-1400000.index
model.ckpt-1400000.meta
model.ckpt-1400000.data-00000-of-00001
- BERTモデルのメタグラフ形式 (今回は使用しません)
graph.pbtxt
- SentencePieceモデル
wiki-ja.model
wiki-ja.vocab
ここではGoogle Driveの My Drive/NLP/bert_yoheikikutasan/
直下に保存するものとします。
0-2. Google Driveのマウント
つづいてGoogle Colaboratoryに入り, 仮想マシンにGoogle Driveをマウントします。
from google.colab import drive import pathlib # Google Driveをマウントする仮想マシン上のディレクトリ DIR_DRIVE = pathlib.Path('./gdrive/') # Google Drive上でのNotebook等の各種ファイルのパス DIR_COLAB = DIR_DRIVE / 'My Drive/Colab Notebooks/' DIR_PROJCET = DIR_COLAB / 'livedoor_classification/' # Google Driveをマウント drive.mount(DIR_DRIVE)
標準出力にしたがってアカウント認証と認証コードの入力を行い, マウントを完了させます。
1. 訓練済みBERTモデルの変換
1-0. 方針
つづいて訓練済みBERTモデルをTensorFlow用からPyTorch用に変換していきましょう。
PyTorch側でモデルの"ガワ"だけ作っておき, そこにTensorFlow用モデルの重み行列の中身を流し込むイメージです。
PyTorch用モデルの"ガワ"はゼロから設計はせず, PyTorch用BERT族の定番ライブラリ(Hugging Face Transformers)を利用します。
1-1. Hugging Face transformersの準備
Hugging Face Transformersをインストールし, モデルの枠組みをつくります。
!pip install transformers
from transformers import BertConfig, BertForPreTraining, BertTokenizer, BertModel # configの用意 (語彙数は30522 -> 32000に修正しておく) bertconfig = BertConfig.from_pretrained('bert-base-uncased') bertconfig.vocab_size = 32000 # BERTモデルの"ガワ"の用意 (全パラメーターはランダムに初期化されている) bertmodelforpretraining = BertForPreTraining(bertconfig)
1-2. TensorFlowモデル -> PyTorchモデルの変換
つづいてTensorFlow版BERTの重み行列を読み込み, PyTorch版モデルに読み込みましょう。
これはHugging Face Transformersのメソッド一発で簡単にできます。
なお, この工程でTensorFlowのcheckpointパスが必要になりますが, TensorFlowの文脈で「checkpointのパス」といった場合は .index
, .meta
, .data-XXXXX-of-YYYYY
などの拡張子を除いた部分を指すことに注意が必要です。
DIR_BERT_KIKUTA = DIR_DRIVE / 'My Drive/NLP/bert_yoheikikutasan/' BASE_CKPT = 'model.ckpt-1400000' # 拡張子は含めない # TensorFlowモデルの重み行列を読み込む (数分程度かかる場合がある) bertmodelforpretraining.load_tf_weights(bertconfig, DIR_BERT_KIKUTA / BASE_CKPT) # BERTの本体部分だけ取り出す bertmodel = bertmodelforpretraining.bert
これで無事にPyTorch版日本語BERTを手に入れることができました!
2. Livedoorニュースコーパスでfine-tuningする
2-0. 方針
ここからは, 手に入れた日本語BERTモデルでLivedoorニュースコーパスに対する文書分類タスクを解いてみます。
まずは torchtext を用いて前処理の準備をしていきましょう。
torchtextは主に前処理とミニバッチの切り出しの省力化に特化したライブラリであり, 以下の(1)〜(4)をより少ないコード量で実現することができます。
- (1) train, test用データを1行1サンプルのtsvに変換し, 同ディレクトリに別ファイルとして保存しておく
- (2) 前処理を定義する (
torchtext.data.Field
) - (3) tsvの各カラムに Field を割り当て, 前処理を一括で実行 (
torchtext.data.Dataset
) - (4) 訓練時にミニバッチを自動的に取り出す (
torchtext.data.Iterator
)
2-1. tsvファイルの作成
詳しくはこちらを参照してください。 radiology-nlp.hatenablog.com
Livedoorニュースコーパスは9種類の記事からなるため, ここではtsvのカラムは左から順に元のテキストファイル名, 記事本文, ラベル9個のone-hot encoding の形式としました。
filename article dokujo_tsushin it-life-hack ... topic-news hogehoge.txt fugafuga 1 0 ... 0
2-2. 訓練済みサブワード分割器の読み込み
つづいて, BERTの事前学習で使用されたのと同じサブワード分割器 (SentencePiece) のモデルを読み込み, 復元します。
まず SentencePiece をインストールしましょう。
!pip install sentencepiece
次に SentencePiece モデルを読み込みます。
import sentencepiece as sp BASE_SPM = 'wiki-ja.model' BASE_VOCAB = 'wiki-ja.vocab' # 一旦空の SentencePiece モデルを作成 spm = sp.SentencePieceProcessor() # 読み込み. 成功すると True が返る spm.Load(DIR_BERT_KIKUTA / BASE_SPM)
2-3. 前処理の定義 (torchtext.data.Field
)
前処理を torchtext.data.Field
に定義していきましょう。
基本的に1つのFieldは1種類の前処理しか行うことができません。
このため, 2種類以上の前処理を行う場合はそれぞれについてFieldをつくっておく必要があります。
また, クラスラベル等のように "前処理を何も行わない" カラムに対しても, 前処理を何も行わないことを定義したFieldがやはり必要です。
したがって, 文書分類タスクの場合,少なくとも入力文用とクラスラベル用の2種類のFieldが必要になります。
では前処理をどのように定義するかというと, torchtext.data.Field
のコンストラクタに callable を渡すことでその callable の内容を前処理として実行させることができます。
radiology-nlp.hatenablog.com
ここでは↑の記事で作っておいた分かち書き用クラスを使います。
これを使うと, どのような分かち書き器に対しても同じコードで分かち書きができるようになります。
MAX_LEN = 256 stp = SentencePieceTextProcessor(spm, MAX_LEN) passage = '吾輩は猫である。' stp.to_wordpieces(passage) # ['▁', '吾', '輩', 'は', '猫', 'である', '。'] stp.to_token_ids(passage) # [9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29] stp.to_bert_input(passage) # [4, 9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29, 3, ..., 3, 5]
無事に前処理が定義できたところで, この前処理の機能を搭載したFieldをつくりましょう。
その他に, クラスラベル用に前処理を何も行わないFieldも定義します。
import torch import torch.nn as nn import torch.optim as optim import torchtext # 文を分かち書きしてBERT形式のID列に変換するField field_text = torchtext.data.Field(sequential=True, use_vocab=False, batch_first=True, tokenize=stp.to_bert_input, include_lengths=True) # 何もしないField field_label = torchtext.data.Field(sequential=False, use_vocab=False)
2-4. 前処理の一括実行 (torchtext.data.Dataset
)
これで, データセットのtsvファイルに対して前処理を一括で行う準備ができました。
ここまで来れば, データセットに前処理を施して torch.Tensor
形式に変換したものを短いコードで得ることができます。
まずは, データセットのtsvのどのカラムにどの前処理を割り当てたいかを指定しましょう。
import random PATH_TSV = '' # 3-1. で作成したtsvのパス # tsvの各カラムに割り当てる名前とFieldを指定する # (field_name, torchtext.data.Field) のタプルを容れたリスト # torchtext.data.Fieldは反復使用してよい # field_nameは反復使用不可 N_CLASS = 9 fields_livedoor = [('filename', None), ('text', field_text)] + [('label_{}'.format(i), field_label) for i in range(N_CLASS)]
続いてtorchtext.data.Dataset
のコンストラクタを実行しましょう。
すると, ここでtsvファイルからデータが読み出され, 前処理が実行され, その結果がtorchtext.data.Dataset
オブジェクトに格納されます。
# tsvファイルの各カラムに対応するFieldに割り当てられた前処理が実行され, 結果がtorch.Tensorで格納される ds = torchtext.data.TabularDataset(path=PATH_TSV, format='tsv', skip_header=True, fields=fields_livedoor) # train/val/testを分離 ds_train, ds_val, ds_test = ds.split(split_ratio=[0.8, 0.1, 0.1], random_state=random.seed(42))
2-5. torchtext.data.Iterator
の作成
次に, ミニバッチの切り出しを楽にしてくれる torchtext.data.Iterator
をつくりましょう。
これも短いコードで書くことができます。
BATCH_SIZE = 32 dl_train = torchtext.data.Iterator(ds_train, batch_size=BATCH_SIZE, train=True) dl_val = torchtext.data.Iterator(ds_val, batch_size=BATCH_SIZE, train=False, sort=False) dl_test = torchtext.data.Iterator(ds_test, batch_size=BATCH_SIZE, train=False, sort=False)
イテレーターを回すとミニバッチを取り出すことができます。
ミニバッチはtorchtext.data.Batch
オブジェクトで与えられており, ミニバッチの行列はこのオブジェクトのプロパティとして格納されています。
# 試しにdl_trainを回してみる for batch in dl_train: # プロパティ名はtorchtext.data.Dataset定義時に与えたfield_nameと同じ print(batch.text) # torch.LongTensor of size 32x256 print(batch.label_0) # torch.LongTensor of size 32 print(batch.label_1) # torch.LongTensor of size 32 print(batch.label_2) # torch.LongTensor of size 32 print(batch.label_3) # torch.LongTensor of size 32 print(batch.label_4) # torch.LongTensor of size 32 print(batch.label_5) # torch.LongTensor of size 32 print(batch.label_6) # torch.LongTensor of size 32 print(batch.label_7) # torch.LongTensor of size 32 print(batch.label_8) # torch.LongTensor of size 32
3. 学習する
3-1. BERTのfine-tuningするパラメーターを指定
続いてBERTのどの層のパラメータを固定し, どの層をfine-tuningするかを指定します。
ここではBERT Encoder layer 12層すべてと, それに続くPoolerをfine-tuningすることにしましょう。
# 一旦BERTの全レイヤーのfine-tuningを無効にする for _, param in bertmodel.named_parameters(): param.requires_grad = False # Encoder layerのfine-tuningを有効化 for layer in bertmodel.encoder.layer: for _, param in layer.named_parameters(): param.requires_grad = True # Poolerのfine-tuningを有効化 for _, param in bertmodel.pooler.named_parameters(): param.requires_grad = True
3-2. 学習用コード
ここまで来れば, あとは実際の学習のためのコードを書くだけです! あと一歩!
...と言いたいところですが, PyTorchではこの残りの一歩のためにかなり長いコードを書かなくてはいけません。
(実際に書いたことのある方はお分かりかと思います)
そこで, ここでは学習周りのコードを大幅に簡略化できるラッパーの1つ, PyTorch-Lightningを使っていきましょう。
PyTorch-Lightningとは何者で, 何が嬉しいのかは↓の記事に詳しいです。
qiita.com
まずLightningModule
を継承したクラスに, ネットワーク構造, 与えるデータ, 学習のプロセスを指定していきましょう。
生のPyTorchで書くと, でかいfor文を回したり, train/validation/testで細部を変えたりするのが大変ですが, PyTorch-Lightningはあらかじめ与えられた項目を穴埋めするだけでこれらがすべて完成する仕組みになっています。
import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping class LivedoorClassifier(pl.LightningModule): def __init__(self, bertmodel): # モデルの構造を記述 super().__init__() self.bert_model = bertmodel self.bert_hidden_dim = self.bert_model.config.hidden_size # 768 self.affine = nn.Linear(self.bert_hidden_dim, 9) self.logsoftmax = nn.LogSoftmaxe(dim=1) self.postprocess = nn.Sequential(self.affine, self.logsoftmax) self.lossfunc = nn.NLLLoss(reduction='none') def forward(self, inputs, **kwargs): # モデルの推論を記述 # model_output: size (n_batch, 9) model_output = self.postprocess(self.bert_model(inputs)[1]) return model_output def training_step(self, batch, batch_nb): # trainのミニバッチに対して行う処理 """ (batch) -> (dict or OrderedDict) # Caution: key for loss function must exactly be 'loss'. """ # X: size (n_batch, max_len) X = batch.text[0] # T: size (n_batch, 9) (一旦one-hot vector化する) T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1) # T: size (n_batch) (正解クラスの番号のみ保持) T = torch.argmax(T, dim=1) # GPU使用中ならX, Tを CPU -> GPU に移動させる X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device) T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long() # 各クラスに対する対数尤度: size (n_batch, 9) logPY = self.forward(X) # 損失関数: size (n_batch) loss = self.lossfunc(logPY, T) # 推測したクラス: size (n_batch) Y = torch.argmax(logPY, dim=1).long().detach() progress_bar = {'loss':loss} log = {'loss':loss} returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log} return returns def validation_step(self, batch, batch_nb, *dataloader_ix): # validationのミニバッチに対して行う処理 """ (batch) -> (dict or OrderedDict) """ # X: size (n_batch, max_len) X = batch.text[0] # T: size (n_batch, 9) (一旦one-hot vector化する) T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1) # T: size (n_batch) (正解クラスの番号のみ保持) T = torch.argmax(T, dim=1) # GPU使用中ならX, Tを CPU -> GPU に移動させる X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device) T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long() # 各クラスに対する対数尤度: size (n_batch, 9) logPY = self.forward(X) # 損失関数: size (n_batch) loss = self.lossfunc(logPY, T) # 推測したクラス: size (n_batch) Y = torch.argmax(logPY, dim=1).long().detach() progress_bar = {'loss':loss} log = {'loss':loss} returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log} return returns def test_step(self, batch, batch_nb, *dataloader_ix): # testのミニバッチに対して行う処理 """ (batch) -> (dict or OrderedDict) """ # X: size (n_batch, max_len) X = batch.text[0] # T: size (n_batch, 9) (一旦one-hot vector化する) T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1) # T: size (n_batch) (正解クラスの番号のみ保持) T = torch.argmax(T, dim=1) # GPU使用中ならX, Tを CPU -> GPU に移動させる X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device) T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long() # 各クラスに対する対数尤度: size (n_batch, 9) logPY = self.forward(X) # 損失関数: size (n_batch) loss = self.lossfunc(logPY, T) # 推測したクラス: size (n_batch) Y = torch.argmax(logPY, dim=1).long().detach() progress_bar = {'loss':loss} log = {'loss':loss} returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log} return returns def training_end(self, outputs): # trainのミニバッチ1個が終わったときの結果に対する処理 """ outputs(dict) -> loss(dict or OrderedDict) # Caution: key must exactly be 'loss'. """ loss = torch.mean(outputs['loss']) progress_bar = {'loss':loss} log = {'loss':loss} returns = {'loss':loss, 'progress_bar':progress_bar, 'log':log} return returns def validation_end(self, outputs): # validationのミニバッチ全部が終わったときの結果に対する処理 """ For single dataloader: outputs(list of dict) -> (dict or OrderedDict) For multiple dataloaders: outputs(list of (list of dict)) -> (dict or OrderedDict) """ # 全データに対する損失関数 loss = torch.mean(torch.cat([output['loss'] for output in outputs])) # 全データに対する精度 acc = torch.mean(torch.cat([(output['label'] == output['pred']) * 1.0 for output in outputs])) progress_bar = {'val_loss':loss, 'val_acc':acc} log = {'val_loss':loss, 'val_acc':acc} returns = {'val_loss':loss, 'progress_bar':progress_bar, 'log':log} return returns def test_end(self, outputs): # testのミニバッチ全部が終わったときの結果に対する処理 """ For single dataloader: outputs(list of dict) -> (dict or OrderedDict) For multiple dataloaders: outputs(list of (list of dict)) -> (dict or OrderedDict) """ # 全データに対する損失関数 loss = torch.mean(torch.cat([output['loss'] for output in outputs])) # 全データに対する精度 acc = torch.mean(torch.cat([(output['label'] == output['pred']) * 1.0 for output in outputs])) progress_bar = {'test_loss':loss, 'test_acc':acc} log = {'test_loss':loss, 'test_acc':acc} returns = {'test_loss':loss, 'progress_bar':progress_bar, 'log':log} return returns def configure_optimizers(self): # Optimizer, schedulerを指定する # ここでは学習率2e-5でスタートし, 3, 5epoch目でそれぞれ学習率を0.1倍する optimizer = optim.Adam(self.parameters(), lr=2e-5) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 5], gamma=0.1) return [optimizer], [scheduler] @pl.data_loader def train_dataloader(self): # torch.utils.data.DataLoader を返させる # torchtext.data.Iterator でも可 return dl_train @pl.data_loader def val_dataloader(self): # torch.utils.data.DataLoader を返させる # torchtext.data.Iterator でも可 return dl_val @pl.data_loader def test_dataloader(self): # torch.utils.data.DataLoader を返させる # torchtext.data.Iterator でも可 return dl_test
これでもコードはかなり長く見えますが, PyTorch-Lightningを使わないと体感的にはこの倍くらいの長さになります!
クラスが定義できたら, コンストラクタでモデルを作成しましょう。
# モデルインスタンスを作成 model = LivedoorClassifier(bertmodel) # モデルをGPUに移す device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu' model.to(device)
続いて, 「Epoch数」「Early stoppingするかどうか」「ログはどう残すか」など, 学習そのものよりも一歩抽象度の高い, 実験そのものに関するハイパーパラメーターを定義していきましょう。
# Validation lossが3回続けて上昇したら学習をストップさせる early_stop_callback = EarlyStopping(monitor='val_loss', patience=3, mode='min') # ハイパーパラメーターをTrainerに与える trainer = Trainer( early_stop_callback=early_stop_callback, show_progress_bar=True, log_gpu_memory='all', max_nb_epochs=20 )
あとはtrainer.fit()
と書くだけで学習がスタートします!
trainer.fit(model)
なお, 初回はtqdm
モジュールに関するエラーが出て学習がスタートしない場合がありますが, その場合はランタイムを一旦再起動してコードを実行し直してください。
学習が完了したら, 次に
trainer.test()
と書くことでtest用データでの推論が走ります。
3-3. 結果
Test set での Accuracy は 94.71%となりました.
4. おわりに
本記事で試した内容は以下のとおりです:
- TensorFlow用学習済みモデルをPyTorch用に変換した
- Torchtextを使って前処理のコードを簡略化した
- PyTorch-Lightningを使って学習のコードを簡略化した
だいぶ欲張った内容となりましたが, 少しでも初心者の方の参考になれば幸いです。