パスワードを忘れた? アカウント作成
16491725 journal
人工知能

yasuokaの日記: ku-nlp/deberta-v2-base-japaneseのトークナイザをJuman++に繋ぎつつJCommonSenseQAでファインチューニング 1

日記 by yasuoka

1月20日の日記の手法をもとに、ku-nlp/deberta-v2-base-japaneseJGLUEのJCommonSenseQAでファインチューニングしてみた。Google Colaboratory (GPU版)だと、こんな感じ。

!test -d jumanpp-2.0.0-rc3 || curl -L https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc3/jumanpp-2.0.0-rc3.tar.xz | tar xJf -
!test -x /usr/local/bin/jumanpp || ( mkdir jumanpp-2.0.0-rc3/build && cd jumanpp-2.0.0-rc3/build && cmake .. -DCMAKE_BUILD_TYPE=Release && make install )
!test -d transformers-4.26.0 || git clone -b v4.26.0 --depth=1 https://github.com/huggingface/transformers transformers-4.26.0
!test -d JGLUE || ( git clone --depth=1 https://github.com/yahoojapan/JGLUE && cat JGLUE/fine-tuning/patch/transformers-4.9.2_jglue-1.1.0.patch | ( cd transformers-4.26.0 && patch -p1 ) )
!cd transformers-4.26.0 && pip install .
!pip install -r transformers-4.26.0/examples/pytorch/text-classification/requirements.txt
!pip install protobuf==3.19.1 tensorboard pytextspan rhoknp
from transformers import DebertaV2TokenizerFast,AutoModelForMaskedLM
tkz=DebertaV2TokenizerFast.from_pretrained("ku-nlp/deberta-v2-base-japanese")
mdl=AutoModelForMaskedLM.from_pretrained("ku-nlp/deberta-v2-base-japanese")
tkz.__class__.__name__="JumanppDebertaV2TokenizerFast"
tkz.init_kwargs["auto_map"]={"AutoTokenizer":[None,"tokenizer.JumanppDebertaV2TokenizerFast"]}
tkz.save_pretrained("deberta-v2-base-japanese")
mdl.save_pretrained("deberta-v2-base-japanese")
s='''#! /usr/bin/python3
from transformers import DebertaV2TokenizerFast
from transformers.models.bert_japanese.tokenization_bert_japanese import JumanppTokenizer
class JumanppPreTokenizer(JumanppTokenizer):
  def jumanpp_split(self,i,normalized_string):
    import textspan
    t=str(normalized_string)
    k=self.tokenize(t)
    return [normalized_string[s:e] for c in textspan.get_original_spans(k,t) for s,e in c]
  def pre_tokenize(self,pretok):
    pretok.split(self.jumanpp_split)
class JumanppDebertaV2TokenizerFast(DebertaV2TokenizerFast):
  def __init__(self,**kwargs):
    from tokenizers.pre_tokenizers import PreTokenizer,Metaspace,Sequence
    super().__init__(**kwargs)
    self._tokenizer.pre_tokenizer=Sequence([PreTokenizer.custom(JumanppPreTokenizer()),Metaspace()])
  def save_pretrained(self,save_directory,**kwargs):
    import os
    import shutil
    from tokenizers.pre_tokenizers import PreTokenizer,Metaspace,Sequence
    self._auto_map={"AutoTokenizer":[None,"tokenizer.JumanppDebertaV2TokenizerFast"]}
    self._tokenizer.pre_tokenizer=Metaspace()
    super().save_pretrained(save_directory,**kwargs)
    self._tokenizer.pre_tokenizer=Sequence([PreTokenizer.custom(JumanppPreTokenizer()),Metaspace()])
    shutil.copy(os.path.abspath(__file__),os.path.join(save_directory,"tokenizer.py"))'''
with open("deberta-v2-base-japanese/tokenizer.py","w",encoding="utf-8") as w:
  print(s,file=w)
f="transformers-4.26.0/examples/pytorch/multiple-choice/run_swag.py"
!if fgrep trust_remote_code {f} ; then : ; else ( echo '%s/use_fast=.*,/& trust_remote_code=True,/' ; echo wq ) | ex -s {f} ; fi
!python {f} --model_name_or_path deberta-v2-base-japanese --do_train --do_eval --do_predict --max_seq_length 64 --per_device_train_batch_size 12 --learning_rate 5e-05 --num_train_epochs 4 --output_dir ./deberta-v2-base-japanese-jcommonsenseqa --overwrite_output_dir --train_file JGLUE/datasets/jcommonsenseqa-v1.1/train-v1.1.json --validation_file JGLUE/datasets/jcommonsenseqa-v1.1/valid-v1.1.json --test_file JGLUE/datasets/jcommonsenseqa-v1.1/valid-v1.1.json --use_fast_tokenizer True --evaluation_strategy epoch --warmup_ratio 0.1

auto_mapとtrust_remote_code=Trueがトリッキーだが、Juman++のインストールも含め、GPU版なら30分ほどでdeberta-v2-base-japanese-jcommonsenseqaが出来上がる。私(安岡孝一)の手元では、以下の「eval metrics」が出力された。

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =     0.8606
  eval_loss               =     0.6361
  eval_runtime            = 0:00:12.16
  eval_samples            =       1119
  eval_samples_per_second =     91.989
  eval_steps_per_second   =     11.509

JCommonSenseQAが0.8606なので、私が以前作ったdeberta-base-japanese-wikipediaより、断然高い。ただ、これ、Juman++だけじゃなくて、rhoknpもpytextspanも必要なので、さて、どうしたものかなぁ。

最新の日記

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
typodupeerror

アレゲは一日にしてならず -- アレゲ研究家

読み込み中...