Juman++, SentencePiece, BERT tokenizerの分かち書きを同じコードで書くための抽象クラス

0. 動機

自然言語処理のためには, 入力文を分かち書きし, 各トークンを数値に変換しなくてはなりません。

分かち書きのためのモジュールは Janome(MeCab), Juman++, SentencePiece, BERT tokenizer など色々提供されています。

しかし, 厄介なことに, これらは

など, 分かち書きをどの粒度で行うのか, 処理をどの段階まで行うのかの点でバラバラです。ややこしいですね。

さらに後述しますが, BERTの登場により, 1つの前処理なのに2つ以上の分かち書き器を組み合わせて使わなければならない場面も登場してきました。ますますややこしいですね。

そんなときに, どんな分かち書き器に対しても共通のコードで分かち書きをしたい!というのが動機です。

(もう一つの動機は, 単に私がPythonで抽象クラスを書く練習がしたかったというだけ)

注: 自然言語処理で tokenize, tokenizer という場合, 分かち書きを指すのか, IDへの変換も含めて意味するのかは文脈により曖昧な印象があります。このため, 一般的な表記ではないかもしれませんが本記事では分かち書き, 分かち書き器のように呼びます。

1. 基底クラスの定義

あらゆる分かち書き器に対応するための抽象クラスを定義しましょう。

まず, 使いたい分かち書き器が搭載している機能のうち

がどんなものであるかを指示するための空のメソッドを用意しておきます(@abstractmethodでデコレートされているもの)。

これらの空のメソッドはクラス継承時に必ずoverrideしなければなりませんが, そのoverrideさえきちんと行えば, 実際の分かち書きやBERT入力形式への変換作業はあらかじめ基底クラスに定義してあるため, あらためて考えなくてもよい設計にしています。

from abc import ABCMeta, abstractmethod

class BaseTextProcessorForDataField(object):
    """
                raw text     -┐[1]
          [2]┌- cleaned text <┘
             └> [words]      -┐[3]
          [4]┌- [wordpieces] <┘
             └> [token ids]  -┐[5]
                BERT input   <┘

            [1]: text cleaning
            [2]: base tokenization
            [3]: wordpiece tokenization
            [4]: convert to token ids
            [5]: convert to BERT input
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        self.unk_id = -1
        self.cls_id = -1
        self.sep_id = -1
        self.pad_id = -1
        self.fix_length = 0
        self._enable_base_punctuation = True
        self._enable_wordpiece_punctuation = True
        self._enable_punctuation = True
        self._enable_tokenization = True
        self.errmsg = '{} is not offered in the tokenizer in use in the first place.'

    # MeCab, Juman++, SentencePiece, BertTokenizer等の処理工程の違いに対応する準備
    # 抽象クラスのメソッドを用意する
    # 抽象クラス継承時には必ずoverrideする必要がある
    @abstractmethod
    def clean(self, text):
        # [1]
        pass
    @abstractmethod
    def punctuate_base(self, text):
        # [2]
        pass
    @abstractmethod
    def punctuate_wordpiece(self, word):
        # [3]
        pass
    @abstractmethod
    def punctuate(self, text):
        # [2] + [3]
        pass
    @abstractmethod
    def convert_token_to_id(self, token):
        # [4]
        pass
    @abstractmethod
    def tokenize_at_once(self, text):
        # [2] + [3] + [4]
        pass

    # 以下は実際に分かち書きするためのメソッド(override不要)
    def to_words(self, text):
        """
        Apply only basic tokenization to text.
        ------------
            raw-text
               | <- (this function)
            cleaned-text
               | <- (this function)
            [tokens_words]
               |
            [tokens_wordpieces]
               |
            [token-ids]
               |
            [cls-id token-ids pad-ids sep-ids]
        ------------
        Inputs: text(str) - raw passage
        Outs: (list(str)) - passage split into tokens
        """
        # [1] + [2]
        if self._enable_base_punctuation:
            # Base puncuation を行う
            return self.punctuate_base(self.clean(text))
        else:
            print(self.errmsg.format('Base punctuation'))

    def to_wordpieces(self, text):
        """
        Apply basic tokenization & wordpiece tokenization to text.
        ------------
            raw-text
               | <- (this function)
            cleaned-text
               | <- (this function)
            [tokens_words]
               | <- (this function)
            [tokens_wordpieces]
               | <- (this function)
            [token-ids]
               |
            [cls-id token-ids pad-ids sep-ids]
        ------------
        Inputs: text(str) - raw passage
        Outs: (list(str)) - passage split into tokens
        """
        # [1] + [2] + [3]
        if self._enable_punctuation:
            # Base puncuation と Wordpiece Punctuation を行う
            # 分かち書き器が[2]+[3]を単一メソッドで提供している場合
            return self.punctuate(self.clean(text))
        elif self._enable_base_punctuation and self._enable_wordpiece_punctuation:
            # Base puncuation と Wordpiece Punctuation を行う
            # 分かち書き器が[2],[3]を別メソッドで提供している場合
            wordpieces = []
            for word in self.puncuate_base(self.clean(text)):
                wordpieces += self.punctuate_wordpiece(word)
            return wordpieces
        elif self._enable_base_punctuation:
            # Base puncuation のみ行う
            # 分かち書き器がWordpiece Punctuationに対応していない場合
            return self.to_words(text)

    def to_token_ids(self, text):
        """
        Apply cleaning, punctuation and id-conversion to text. 
        ------------
            raw-text
               | <- (this function)
            cleaned-text
               | <- (this function)
            [tokens_words]
               | <- (this function)
            [tokens_wordpieces]
               | <- (this function)
            [token-ids]
               | <- (this function)
            [cls-id token-ids pad-ids sep-ids]
        ------------
        Inputs: text(str) - raw passage
        Outs: (list(int)) - list of token ids
        """
        # [1] + [2] + [3] + [4]
        if self._enable_tokenization:
            # 分かち書き器が[2]+[3]+[4]を単一メソッドで提供している場合
            return self.tokenize_at_once(self.clean(text))
        else:
            # 分かち書き器が[2]+[3]+[4]を単一メソッドで提供していない場合
            return [ self.convert_token_to_id(token) for token in self.to_wordpieces(text) ]

    def to_bert_input(self, text):
        """
        Obtain BERT style token ids from text.
        ------------
            raw-text
               | <- (this function)
            cleaned-text
               | <- (this function)
            [tokens_words]
               | <- (this function)
            [tokens_wordpieces]
               | <- (this function)
            [token-ids]
               | <- (this function)
            [cls-id token-ids pad-ids sep-ids]
        ------------
        Inputs: text(str) - raw passage
        Outs: (list(int)) - list of token ids
        """
        # [1] + [2] + [3] + [4] + [5]
        # [CLS] <入力文のトークンID列> [PAD] ... [PAD] [SEP] 形式のID列にして返す
        padded = [self.cls_id] + self.to_token_ids(text) + [self.pad_id] * self.fix_length
        return padded[:self.fix_length-1] + [self.sep_id]

2. 実践

これで, どのような仕様の分かち書き器を使ったとしても

の好きな段階まで処理を行うことができ, コードを共通化することができます!
この抽象クラスを継承して, さまざまな分かち書きに対するコードを共通化させていきましょう。

2-1. Juman++ & BERT の場合

一般に, 訓練済みモデルをfine-tuningして使いたい場合, 分かち書きは訓練済みモデルの事前学習に使われたのと同じ方法で行われなければなりません。

これは, 京大黒橋研から公開されている訓練済み日本語BERTモデルを使いたい場合に特に問題になってきます。

京大黒橋研モデルはどのような分かち書きで事前学習されているのでしょうか? 公式HPによると

  • Juman++を用いて形態素に分割し,
  • その後, BERT付属のtokeizerを用いて形態素をサブワードへとさらに分割する

と説明されています。

つまり, 京大黒橋研モデルは事前学習時に Juman++とBERT wordpiece tokenizer を組み合わせた分かち書きを使っているため, fine-tuning時にも同じように Juman++とBERT wordpiece tokenizer を組み合わせなければなりません。

そこで, 先ほどの抽象クラスを継承した JumanppBERTTextProcessor クラスを定義していきましょう。
クラスの定義時に行うのは, 抽象クラスのメソッドのoverrideと, [UNK], [CLS], [SEP], [PAD]にあたるID番号を与えることです。

なお, PyKNPとHugging Face Transformersはすでにインストールされているものとします。

from pyknp import Juman
from transformers import BertTokenizer

PATH_VOCAB = './vocab.txt'    # 1行に1語彙が書かれたtxtファイル
jpp = Juman()
tokenizer = BertTokenizer(PATH_VOCAB, do_lower_case=False, do_basic_tokenize=False)

class JumanppBERTTextProcessor(BaseTextProcessorForDataField):
    """
    JumanppBERTTextProcessor(jppmodel, bertwordpiecetokenizer, fix_length) -> object

    Inputs
    ------
    jppmodel(pyknp.Juman):
        Juman++ tokenizer.
    bertwordpiecetokenizer(transformers.BertTokenizer):
        BERT tokenizer offered by huggingface.co transformers.
        This must be initialized with do_basic_tokenize=False.
    fix_length(int):
        Desired length of resulting BERT input including [CLS], [PAD] and [SEP].
        Longer sentences will be truncated.
        Shorter sentences will be padded.
        """
    def __init__(self, jppmodel, bertwordpiecetokenizer, fix_length):
        self.unk_id = bertwordpiecetokenizer.vocab['[UNK]']
        self.cls_id = bertwordpiecetokenizer.vocab['[CLS]']
        self.sep_id = bertwordpiecetokenizer.vocab['[SEP]']
        self.pad_id = bertwordpiecetokenizer.vocab['[PAD]']
        self.fix_length = fix_length
        self.enable_base_punctuation = True
        self.enable_wordpiece_punctuation = True
        self.enable_punctuation = False
        self.enable_tokenization = False

    # abstractメソッドのoverride
    def clean(self, text):
        return text
    def punctuate_base(self, text):
        return [mrph.midasi for mrph in jppmodel.analysis(self.clean(text)).mrph_list()]
    def punctuate_wordpiece(self, word):
        return bertwordpiecetokenizer.tokenize(word)
    def punctuate(self, text):
        pass
    def convert_token_to_id(self, token):
        try:
            return bertwordpiecetokenizer.vocab[token]
        except KeyError:
            return self.unk_id
    def tokenize_at_once(self, text):
        pass

さて, これで「Juman++で形態素へ, 続いてBERT tokenizerでサブワードへ分かち書きする」処理を短いコードで書くことができます。

jbp = JumanppBERTTextProcessor(jpp, bertwordpiecetokenizer, fix_length=256)
passage = '胸部単純CTを撮像しました。'

jbp.to_words(passage)
# ['胸部', '単純', 'CT', 'を', '撮像', 'し', 'ました']
jbp.to_wordpieces(passage)
# ['胸部', '単純', '[UNK]', 'を', '撮', '##像', 'し', 'ました']
jbp.to_token_ids(passage)
# [15166, 8420, 1, 10, 17015, 55083, 31, 4561]
jbp.to_bert_input(passage)
# [2, 15166, 8420, 1, 10, 17015, 55083, 31, 4561, 0, ..., 0, 3]

2-2. SentencePieceの場合

SentencePiece の場合はもっと単純です。
日本語の文法的な概念としての "単語" にこだわらない設計になっており, すべてがサブワードとして扱われます。
したがってBase tokenization と Wordpiece tokenization の区別もありません。

import sentencepiece as sp
SPM_MODEL_PATH = ''    # 訓練済みSentencePieceモデルのパス
spm = SentencePieceProcessor()
spm.Load(SPM_MODEL_PATH)

class SentencePieceTextProcessor(BaseTextProcessorForDataField):
    def __init__(self, spmodel, fix_length):
        super().__init__()
        self.unk_id = spmodel.PieceToId('<unk>')
        self.cls_id = spmodel.PieceToId('[CLS]')
        self.sep_id = spmodel.PieceToId('[SEP]')
        self.pad_id = spmodel.PieceToId('[PAD]')
        self.fix_length = fix_length
        self._enable_base_punctuation = False
        self._enable_wordpiece_punctuation = False
        self._enable_punctuation = True
        self._enable_tokenization = True
        self.spmodel = spmodel

    def clean(self, text):
        # ここではstopword除去などは行わない
        return text

    def punctuate_base(self, text):
        pass

    def punctuate_wordpiece(self, text):
        pass

    def punctuate(self, text):
        return self.spmodel.EncodeAsPieces(text)

    def convert_token_to_id(self, token):
        return self.spmodel.PieceToId(token)

    def tokenize_at_once(self, text):
        return self.spmodel.EncodeAsIds(text)

実際に SentencePiece で入力文を BERT入力形式に変換すると以下のようになります。

stp = SentencePieceTextProcessor(spm, fix_length=256)
passage = '胸部単純CTを撮像しました'

stp.to_words(passage)
# ['▁', '胸部', '単純', 'C', 'T', 'を', '撮', '像', 'しま', 'した']
stp.to_wordpieces(passage)
# ['▁', '胸部', '単純', 'C', 'T', 'を', '撮', '像', 'しま', 'した']
stp.to_token_ids(passage)
# [9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29]
stp.to_bert_input(passage)
# [4, 9, 20854, 9947, 4167, 0, 18, 10032, 1164, 3899, 29, 3, ..., 3, 5]

3. 終わりに

どんな分かち書き器に対してもなるべく同じコードで分かち書きができるような抽象クラスを作りました。

利点

  • 単語IDの辞書が分かち書き器自体に保持されている場合にも, 外部にある場合にも, 全く同じコードで分かち書きできる
  • 京大黒橋研日本語BERTのように2つの分かち書き器を組み合わせなければならない場面でもコードが短くなる

欠点

  • 簡単なことを難しく書いている印象は否めない