日本語BERTモデルをPyTorch用に変換してfine-tuningする with torchtext & pytorch-lightning

TL;DR

①TensorFlow版訓練済みモデルをPyTorch用に変換した
 (→方法だけ読みたい方はこちら)

②①をスムーズに使うための torchtext.data.Dataset を設計した

③PyTorch-Lightningを使ってコードを短くした

はじめに

日本語Wikipediaで事前学習されたBERTモデルとしては, 以下の2つが有名であり, 広く普及しています:

  • SentencePieceベースのモデル (Yohei Kikuta さん提供)
    • TensorFlow版
  • Juman++ベースのモデル (京大黒橋研提供)
    • TensorFlow版
    • PyTorch版(Hugging Face transformers準拠)

このうち, SentencePieceベースのものは現在TensorFlow版のみの提供となっており, PyTorch版は存在しません。
そのため, 私のようなPyTorchユーザーでがっくり肩を落とされた方は多いのではないでしょうか?

しかし決して諦めることはありません。
実は, ほんの少し工夫するだけでPyTorch版に変換することは可能です!
早速試していきましょう。

(本記事の手法を試すにあたり, kaggler-ja slackの皆さんには多くの助言をいただきました。この場を借りてお礼申し上げます)

環境
  • Google Colaboratory
    • Python 3.6.9
    • TensorFlow 1.15.0
    • PyTorch 1.3.1
    • Torchtext 0.3.1
    • PyTorch-Lightning 0.5.3.2

実践

0. 下準備

0-1. Yohei Kikutaさん版日本語BERTモデルの取得

f:id:radiology-nlp:20191213130202p:plain こちらに公開されているファイルを取得しておきます。

  • BERTモデルのCheckpoint
    • model.ckpt-1400000.index
    • model.ckpt-1400000.meta
    • model.ckpt-1400000.data-00000-of-00001
  • BERTモデルのメタグラフ形式 (今回は使用しません)
    • graph.pbtxt
  • SentencePieceモデル
    • wiki-ja.model
    • wiki-ja.vocab

ここではGoogle DriveMy Drive/NLP/bert_yoheikikutasan/ 直下に保存するものとします。

f:id:radiology-nlp:20191213130720p:plain

0-2. Google Driveのマウント

つづいてGoogle Colaboratoryに入り, 仮想マシンGoogle Driveをマウントします。

from google.colab import drive

# Google Driveをマウントする仮想マシン上のディレクトリ
DIR_DRIVE = './gdrive/'

# Google Drive上でのNotebook等の各種ファイルのパス
DIR_COLAB = DIR_DRIVE + 'My Drive/Colab Notebooks/'
DIR_PROJCET = DIR_COLAB + 'livedoor_classification/'

# Google Driveをマウント
drive.mount(DIR_DRIVE)

標準出力にしたがってアカウント認証と認証コードの入力を行い, マウントを完了させます。

1. 訓練済みBERTモデルの変換

1-0. 方針

f:id:radiology-nlp:20191213140247p:plain

つづいて訓練済みBERTモデルをTensorFlow用からPyTorch用に変換していきましょう。

PyTorch側でモデルの"ガワ"だけ作っておき, そこにTensorFlow用モデルの重み行列の中身を流し込むイメージです。
PyTorch用モデルの"ガワ"はゼロから設計はせず, PyTorch用BERT族の定番ライブラリ(Hugging Face Transformers)を利用します。

1-1. Hugging Face transformersの準備

Hugging Face Transformersをインストールし, モデルの枠組みをつくります。

!pip install transformers
from transformers import BertConfig, BertForPreTraining, BertTokenizer, BertModel

# configの用意 (語彙数は30522 -> 32000に修正しておく)
bertconfig = BertConfig.from_pretrained('bert-base-uncased')
bertconfig.vocab_size = 32000

# BERTモデルの"ガワ"の用意 (全パラメーターはランダムに初期化されている)
bertmodelforpretraining = BertForPreTraining(bertconfig)

1-2. TensorFlowモデル -> PyTorchモデルの変換

つづいてTensorFlow版BERTの重み行列を読み込み, PyTorch版モデルに読み込みましょう。
これはHugging Face Transformersのメソッド一発で簡単にできます。

なお, この工程でTensorFlowのcheckpointパスが必要になりますが, TensorFlowの文脈で「checkpointのパス」といった場合は .index, .meta, .data-XXXXX-of-YYYYY などの拡張子を除いた部分を指すことに注意が必要です。

DIR_BERT_KIKUTA = DIR_DRIVE + 'My Drive/NLP/bert_yoheikikutasan/'
BASE_CKPT = 'model.ckpt-1400000'    # 拡張子は含めない

# TensorFlowモデルの重み行列を読み込む (数分程度かかる場合がある)
bertmodelforpretraining.load_tf_weights(bertconfig, DIR_BERT_KIKUTA + BASE_CKPT)

# BERTの本体部分だけ取り出す
bertmodel = bertmodelforpretraining.bert

これで無事にPyTorch版日本語BERTを手に入れることができました!

2. Livedoorニュースコーパスでfine-tuningする

2-0. 方針

ここからは, 手に入れた日本語BERTモデルでLivedoorニュースコーパスに対する文書分類タスクを解いてみます。
まずは torchtext を用いて前処理の準備をしていきましょう。

f:id:radiology-nlp:20191213220602p:plain

torchtextは主に前処理とミニバッチの切り出しの省力化に特化したライブラリであり, 以下の(1)〜(4)をより少ないコード量で実現することができます。

  • (1) train, test用データを1行1サンプルのtsvに変換し, 同ディレクトリに別ファイルとして保存しておく
  • (2) 前処理を定義する (torchtext.data.Field)
  • (3) tsvの各カラムに Field を割り当て, 前処理を一括で実行 (torchtext.data.Dataset)
  • (4) 訓練時にミニバッチを自動的に取り出す (torchtext.data.Iterator)

2-1. tsvファイルの作成

詳しくはこちらを参照してください。 radiology-nlp.hatenablog.com

Livedoorニュースコーパスは9種類の記事からなるため, ここではtsvのカラムは左から順に元のテキストファイル名, 記事本文, ラベル9個のone-hot encoding の形式としました。

filename     article  dokujo_tsushin it-life-hack ... topic-news
hogehoge.txt fugafuga 1              0            ... 0

2-2. 訓練済みサブワード分割器の読み込み

つづいて, BERTの事前学習で使用されたのと同じサブワード分割器 (SentencePiece) のモデルを読み込み, 復元します。
まず SentencePiece をインストールしましょう。

!pip install sentencepiece

次に SentencePiece モデルを読み込みます。

import sentencepiece as sp
BASE_SPM = 'wiki-ja.model'
BASE_VOCAB = 'wiki-ja.vocab'

# 一旦空の SentencePiece モデルを作成
spm = sp.SentencePieceProcessor()

# 読み込み. 成功すると True が返る
spm.Load(DIR_BERT_KIKUTA + BASE_SPM)

2-3. 前処理の定義 (torchtext.data.Field)

前処理を torchtext.data.Field に定義していきましょう。
基本的に1つのFieldは1種類の前処理しか行うことができません。
このため, 2種類以上の前処理を行う場合はそれぞれについてFieldをつくっておく必要があります。

また, クラスラベル等のように "前処理を何も行わない" カラムに対しても, 前処理を何も行わないことを定義したFieldがやはり必要です。
したがって, 文書分類タスクの場合,少なくとも入力文用とクラスラベル用の2種類のFieldが必要になります。

では前処理をどのように定義するかというと, torchtext.data.Field のコンストラクタに callable を渡すことでその callable の内容を前処理として実行させることができます。

radiology-nlp.hatenablog.com ここでは↑の記事で作っておいた分かち書き用クラスを使います。
これを使うと, どのような分かち書き器に対しても同じコードで分かち書きができるようになります。

MAX_LEN = 256
stp = SentencePieceTextProcessor(spm, MAX_LEN)

passage = '吾輩は猫である。'

stp.to_wordpieces(passage)
# ['▁', '吾', '輩', 'は', '猫', 'である', '。']

stp.to_token_ids(passage)
# [9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29]

stp.to_bert_input(passage)
# [4, 9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29, 3, ..., 3, 5]

無事に前処理が定義できたところで, この前処理の機能を搭載したFieldをつくりましょう。
その他に, クラスラベル用に前処理を何も行わないFieldも定義します。

import torch
import torch.nn as nn
import torch.optim as optim
import torchtext

# 文を分かち書きしてBERT形式のID列に変換するField
field_text = torchtext.data.Field(sequential=True, use_vocab=False, batch_first=True, tokenize=stp.to_bert_input, include_lengths=True)
# 何もしないField
field_label = torchtext.data.Field(sequential=False, use_vocab=False)

2-4. 前処理の一括実行 (torchtext.data.Dataset)

これで, データセットのtsvファイルに対して前処理を一括で行う準備ができました。
ここまで来れば, データセットに前処理を施して torch.Tensor 形式に変換したものを短いコードで得ることができます。

まずは, データセットのtsvのどのカラムにどの前処理を割り当てたいかを指定しましょう。

import random
PATH_TSV = ''    # 3-1. で作成したtsvのパス

# tsvの各カラムに割り当てる名前とFieldを指定する
# (field_name, torchtext.data.Field) のタプルを容れたリスト
# torchtext.data.Fieldは反復使用してよい
# field_nameは反復使用不可
N_CLASS = 9
fields_livedoor = [('filename', None), ('text', field_text)] + [('label_{}'.format(i), field_label) for i in range(N_CLASS)]

続いてtorchtext.data.Datasetのコンストラクタを実行しましょう。
すると, ここでtsvファイルからデータが読み出され, 前処理が実行され, その結果がtorchtext.data.Datasetオブジェクトに格納されます。

# tsvファイルの各カラムに対応するFieldに割り当てられた前処理が実行され, 結果がtorch.Tensorで格納される
ds = torchtext.data.TabularDataset(path=PATH_TSV, format='tsv', skip_header=True, fields=fields_livedoor)

# train/val/testを分離
ds_train, ds_val, ds_test = ds.split(split_ratio=[0.8, 0.1, 0.1], random_state=random.seed(42))

2-5. torchtext.data.Iteratorの作成

次に, ミニバッチの切り出しを楽にしてくれる torchtext.data.Iterator をつくりましょう。
これも短いコードで書くことができます。

BATCH_SIZE = 32
dl_train = torchtext.data.Iterator(ds_train, batch_size=BATCH_SIZE, train=True)
dl_val = torchtext.data.Iterator(ds_val, batch_size=BATCH_SIZE, train=False, sort=False)
dl_test = torchtext.data.Iterator(ds_test, batch_size=BATCH_SIZE, train=False, sort=False)

イテレーターを回すとミニバッチを取り出すことができます。
ミニバッチはtorchtext.data.Batchオブジェクトで与えられており, ミニバッチの行列はこのオブジェクトのプロパティとして格納されています。

# 試しにdl_trainを回してみる
for batch in dl_train:
    # プロパティ名はtorchtext.data.Dataset定義時に与えたfield_nameと同じ
    print(batch.text)    # torch.LongTensor of size 32x256
    print(batch.label_0)    # torch.LongTensor of size 32
    print(batch.label_1)    # torch.LongTensor of size 32
    print(batch.label_2)    # torch.LongTensor of size 32
    print(batch.label_3)    # torch.LongTensor of size 32
    print(batch.label_4)    # torch.LongTensor of size 32
    print(batch.label_5)    # torch.LongTensor of size 32
    print(batch.label_6)    # torch.LongTensor of size 32
    print(batch.label_7)    # torch.LongTensor of size 32
    print(batch.label_8)    # torch.LongTensor of size 32

3. 学習する

3-1. BERTのfine-tuningするパラメーターを指定

続いてBERTのどの層のパラメータを固定し, どの層をfine-tuningするかを指定します。
ここではBERT Encoder layer 12層すべてと, それに続くPoolerをfine-tuningすることにしましょう。

# 一旦BERTの全レイヤーのfine-tuningを無効にする
for _, param in bertmodel.named_parameters():
    param.requires_grad = False

# Encoder layerのfine-tuningを有効化
for layer in bertmodel.encoder.layer:
    for _, param in layer.named_parameters():
        param.requires_grad = True

# Poolerのfine-tuningを有効化
for _, param in bertmodel.pooler.named_parameters():
    param.requires_grad = True

3-2. 学習用コード

ここまで来れば, あとは実際の学習のためのコードを書くだけです! あと一歩!
...と言いたいところですが, PyTorchではこの残りの一歩のためにかなり長いコードを書かなくてはいけません。
(実際に書いたことのある方はお分かりかと思います)

そこで, ここでは学習周りのコードを大幅に簡略化できるラッパーの1つ, PyTorch-Lightningを使っていきましょう。

PyTorch-Lightningとは何者で, 何が嬉しいのかは↓の記事に詳しいです。
qiita.com

まずLightningModuleを継承したクラスに, ネットワーク構造, 与えるデータ, 学習のプロセスを指定していきましょう。

生のPyTorchで書くと, でかいfor文を回したり, train/validation/testで細部を変えたりするのが大変ですが, PyTorch-Lightningはあらかじめ与えられた項目を穴埋めするだけでこれらがすべて完成する仕組みになっています。

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping


class LivedoorClassifier(pl.LightningModule):
    def __init__(self, bertmodel):
        # モデルの構造を記述
        super().__init__()
        self.bert_model = bertmodel
        self.bert_hidden_dim = self.bert_model.config.hidden_size   # 768
        self.affine = nn.Linear(self.bert_hidden_dim, 9)
        self.logsoftmax = nn.LogSoftmaxe(dim=1)
        self.postprocess = nn.Sequential(self.affine, self.logsoftmax)
        self.lossfunc = nn.NLLLoss(reduction='none')


    def forward(self, inputs, **kwargs):
        # モデルの推論を記述
        # model_output: size (n_batch, 9)
        model_output = self.postprocess(self.bert_model(inputs)[1])
        return model_output


    def training_step(self, batch, batch_nb):
        # trainのミニバッチに対して行う処理
        """
        (batch) -> (dict or OrderedDict)
        # Caution: key for loss function must exactly be 'loss'.
        """
        # X: size (n_batch, max_len)
        X = batch.text[0]
        # T: size (n_batch, 9) (一旦one-hot vector化する)
        T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1)
        # T: size (n_batch) (正解クラスの番号のみ保持)
        T = torch.argmax(T, dim=1)

        # GPU使用中ならX, Tを CPU -> GPU に移動させる
        X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device)
        T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long()

        # 各クラスに対する対数尤度: size (n_batch, 9)
        logPY = self.forward(X)
        # 損失関数: size (n_batch)
        loss = self.lossfunc(logPY, T)
        # 推測したクラス: size (n_batch)
        Y = torch.argmax(logPY, dim=1).long().detach()

        progress_bar = {'loss':loss}
        log = {'loss':loss}
        returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log}
        return returns
 

    def validation_step(self, batch, batch_nb, *dataloader_ix):
        # validationのミニバッチに対して行う処理
        """
        (batch) -> (dict or OrderedDict)
        """
        # X: size (n_batch, max_len)
        X = batch.text[0]
        # T: size (n_batch, 9) (一旦one-hot vector化する)
        T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1)
        # T: size (n_batch) (正解クラスの番号のみ保持)
        T = torch.argmax(T, dim=1)

        # GPU使用中ならX, Tを CPU -> GPU に移動させる
        X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device)
        T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long()

        # 各クラスに対する対数尤度: size (n_batch, 9)
        logPY = self.forward(X)
        # 損失関数: size (n_batch)
        loss = self.lossfunc(logPY, T)
        # 推測したクラス: size (n_batch)
        Y = torch.argmax(logPY, dim=1).long().detach()

        progress_bar = {'loss':loss}
        log = {'loss':loss}
        returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log}
        return returns


    def test_step(self, batch, batch_nb, *dataloader_ix):
        # testのミニバッチに対して行う処理
        """
        (batch) -> (dict or OrderedDict)
        """
        # X: size (n_batch, max_len)
        X = batch.text[0]
        # T: size (n_batch, 9) (一旦one-hot vector化する)
        T = torch.cat([getattr(batch, f'label_{i}').unsqueeze(0) for i in range(9)], dim=0).transpose(0,1)
        # T: size (n_batch) (正解クラスの番号のみ保持)
        T = torch.argmax(T, dim=1)

        # GPU使用中ならX, Tを CPU -> GPU に移動させる
        X = X.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device)
        T = T.to(self.bert_model.state_dict()['embeddings.word_embeddings.weight'].device).long()

        # 各クラスに対する対数尤度: size (n_batch, 9)
        logPY = self.forward(X)
        # 損失関数: size (n_batch)
        loss = self.lossfunc(logPY, T)
        # 推測したクラス: size (n_batch)
        Y = torch.argmax(logPY, dim=1).long().detach()

        progress_bar = {'loss':loss}
        log = {'loss':loss}
        returns = {'loss':loss, 'pred':Y, 'label':T, 'progress_bar':progress_bar, 'log':log}
        return returns


    def training_end(self, outputs):
        # trainのミニバッチ1個が終わったときの結果に対する処理
        """
        outputs(dict) -> loss(dict or OrderedDict)
        # Caution: key must exactly be 'loss'.
        """
        loss = torch.mean(outputs['loss'])

        progress_bar = {'loss':loss}
        log = {'loss':loss}
        returns = {'loss':loss, 'progress_bar':progress_bar, 'log':log}
        return returns


    def validation_end(self, outputs):
        # validationのミニバッチ全部が終わったときの結果に対する処理
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """        
        # 全データに対する損失関数
        loss = torch.mean(torch.cat([output['loss'] for output in outputs]))
        # 全データに対する精度
        acc = torch.mean(torch.cat([(output['label'] == output['pred']) * 1.0 for output in outputs]))

        progress_bar = {'val_loss':loss, 'val_acc':acc}
        log = {'val_loss':loss, 'val_acc':acc}
        returns = {'val_loss':loss, 'progress_bar':progress_bar, 'log':log}
        return returns


    def test_end(self, outputs):
        # testのミニバッチ全部が終わったときの結果に対する処理
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """
        # 全データに対する損失関数
        loss = torch.mean(torch.cat([output['loss'] for output in outputs]))
        # 全データに対する精度
        acc = torch.mean(torch.cat([(output['label'] == output['pred']) * 1.0 for output in outputs]))

        progress_bar = {'test_loss':loss, 'test_acc':acc}
        log = {'test_loss':loss, 'test_acc':acc}
        returns = {'test_loss':loss, 'progress_bar':progress_bar, 'log':log}
        return returns


    def configure_optimizers(self):
        # Optimizer, schedulerを指定する
        # ここでは学習率2e-5でスタートし, 3, 5epoch目でそれぞれ学習率を0.1倍する
        optimizer = optim.Adam(self.parameters(), lr=2e-5)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 5], gamma=0.1)
        return [optimizer], [scheduler]
    
    @pl.data_loader
    def train_dataloader(self):
        # torch.utils.data.DataLoader を返させる
        # torchtext.data.Iterator でも可
        return dl_train

    @pl.data_loader
    def val_dataloader(self):
        # torch.utils.data.DataLoader を返させる
        # torchtext.data.Iterator でも可
        return dl_val

    @pl.data_loader
    def test_dataloader(self):
        # torch.utils.data.DataLoader を返させる
        # torchtext.data.Iterator でも可
        return dl_test

これでもコードはかなり長く見えますが, PyTorch-Lightningを使わないと体感的にはこの倍くらいの長さになります!

クラスが定義できたら, コンストラクタでモデルを作成しましょう。

# モデルインスタンスを作成
model = LivedoorClassifier(bertmodel)

# モデルをGPUに移す
device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
model.to(device)

続いて, 「Epoch数」「Early stoppingするかどうか」「ログはどう残すか」など, 学習そのものよりも一歩抽象度の高い, 実験そのものに関するハイパーパラメーターを定義していきましょう。

# Validation lossが3回続けて上昇したら学習をストップさせる
early_stop_callback = EarlyStopping(monitor='val_loss', patience=3, mode='min')

# ハイパーパラメーターをTrainerに与える
trainer = Trainer(
    early_stop_callback=early_stop_callback,
    show_progress_bar=True,
    log_gpu_memory='all',
    max_nb_epochs=20
)

あとはtrainer.fit()と書くだけで学習がスタートします!

trainer.fit(model)

なお, 初回はtqdmモジュールに関するエラーが出て学習がスタートしない場合がありますが, その場合はランタイムを一旦再起動してコードを実行し直してください。 f:id:radiology-nlp:20200118011824p:plain

学習が完了したら, 次に

trainer.test()

と書くことでtest用データでの推論が走ります。

3-3. 結果

Test set での Accuracy は 94.71%となりました.

4. おわりに

本記事で試した内容は以下のとおりです:

  • TensorFlow用学習済みモデルをPyTorch用に変換した
  • Torchtextを使って前処理のコードを簡略化した
  • PyTorch-Lightningを使って学習のコードを簡略化した

だいぶ欲張った内容となりましたが, 少しでも初心者の方の参考になれば幸いです。

参考資料

つくりながら学ぶ-PyTorchによる発展ディープラーニング