Artificial Intelligence

yasuoka's diary: Hooking the ku-nlp/deberta-v2-base-japanese tokenizer to Juman++ and fine-tuning on JSQuAD

Journal by yasuoka

Applying the technique from yesterday's diary entry, I tried fine-tuning ku-nlp/deberta-v2-base-japanese on JGLUE's JSQuAD. On Google Colaboratory (GPU runtime), it goes something like this.

!test -d jumanpp-2.0.0-rc3 || curl -L https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc3/jumanpp-2.0.0-rc3.tar.xz | tar xJf -
!test -x /usr/local/bin/jumanpp || ( mkdir jumanpp-2.0.0-rc3/build && cd jumanpp-2.0.0-rc3/build && cmake .. -DCMAKE_BUILD_TYPE=Release && make install )
!test -d transformers-4.26.0 || git clone -b v4.26.0 --depth=1 https://github.com/huggingface/transformers transformers-4.26.0
!test -d JGLUE || ( git clone --depth=1 https://github.com/yahoojapan/JGLUE && cat JGLUE/fine-tuning/patch/transformers-4.9.2_jglue-1.1.0.patch | ( cd transformers-4.26.0 && patch -p1 ) )
!cd transformers-4.26.0 && pip install .
!pip install -r transformers-4.26.0/examples/pytorch/text-classification/requirements.txt
!pip install protobuf==3.19.1 tensorboard pytextspan rhoknp
import json
from transformers import DebertaV2TokenizerFast,AutoModelForMaskedLM
tkz=DebertaV2TokenizerFast.from_pretrained("ku-nlp/deberta-v2-base-japanese")
mdl=AutoModelForMaskedLM.from_pretrained("ku-nlp/deberta-v2-base-japanese")
# Rename the tokenizer class and register it in auto_map, so that
# AutoTokenizer with trust_remote_code=True will load the custom class
# from the tokenizer.py written out below
tkz.__class__.__name__="JumanppDebertaV2TokenizerFast"
tkz.init_kwargs["auto_map"]={"AutoTokenizer":[None,"tokenizer.JumanppDebertaV2TokenizerFast"]}
tkz.save_pretrained("deberta-v2-base-japanese")
mdl.save_pretrained("deberta-v2-base-japanese")
s='''#! /usr/bin/python3
from transformers import DebertaV2TokenizerFast
from transformers.models.bert_japanese.tokenization_bert_japanese import JumanppTokenizer
class JumanppPreTokenizer(JumanppTokenizer):
  def jumanpp_split(self,i,normalized_string):
    import textspan
    t=str(normalized_string)
    k=self.tokenize(t)
    return [normalized_string[s:e] for c in textspan.get_original_spans(k,t) for s,e in c]
  def pre_tokenize(self,pretok):
    pretok.split(self.jumanpp_split)
class JumanppDebertaV2TokenizerFast(DebertaV2TokenizerFast):
  def __init__(self,**kwargs):
    from tokenizers.pre_tokenizers import PreTokenizer,Metaspace,Sequence
    super().__init__(**kwargs)
    self._tokenizer.pre_tokenizer=Sequence([PreTokenizer.custom(JumanppPreTokenizer()),Metaspace()])
  def save_pretrained(self,save_directory,**kwargs):
    import os
    import shutil
    from tokenizers.pre_tokenizers import PreTokenizer,Metaspace,Sequence
    self._auto_map={"AutoTokenizer":[None,"tokenizer.JumanppDebertaV2TokenizerFast"]}
    self._tokenizer.pre_tokenizer=Metaspace()
    super().save_pretrained(save_directory,**kwargs)
    self._tokenizer.pre_tokenizer=Sequence([PreTokenizer.custom(JumanppPreTokenizer()),Metaspace()])
    shutil.copy(os.path.abspath(__file__),os.path.join(save_directory,"tokenizer.py"))'''
with open("deberta-v2-base-japanese/tokenizer.py","w",encoding="utf-8") as w:
  print(s,file=w)
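The JumanppPreTokenizer above leans on textspan.get_original_spans to map Juman++ morphemes back to character spans of the original string. As a rough illustration only (not the real library), here is a naive stand-in using sequential search; it ignores the normalization mismatches that textspan actually handles:

```python
# Toy stand-in for textspan.get_original_spans: recover each token's
# character span in the original text by sequential search.
def get_spans(tokens, text):
    spans, pos = [], 0
    for t in tokens:
        i = text.find(t, pos)
        if i < 0:              # token not found verbatim (e.g. normalized away)
            spans.append(None)
            continue
        spans.append((i, i + len(t)))
        pos = i + len(t)
    return spans

# Pretend Juman++ split this string into these four morphemes:
text = "外国人参政権"
tokens = ["外国", "人", "参政", "権"]
print(get_spans(tokens, text))  # [(0, 2), (2, 3), (3, 5), (5, 6)]
```

Each recovered span is then handed back to the tokenizers library as a split point, which is what jumanpp_split does with normalized_string[s:e].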
for f in ["train-v1.1.json","valid-v1.1.json"]:
  with open("JGLUE/datasets/jsquad-v1.1/"+f,"r",encoding="utf-8") as r:
    j=json.load(r)
  u=[]
  for d in j["data"]:
    for p in d["paragraphs"]:
      for q in p["qas"]:
        u.append({"id":q["id"],"title":d["title"],"context":p["context"],"question":q["question"],"answers":{"text":[x["text"] for x in q["answers"]],"answer_start":[x["answer_start"] for x in q["answers"]]}})
  with open(f,"w",encoding="utf-8") as w:
    json.dump({"data":u},w,ensure_ascii=False,indent=2)
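The loop above unnests JSQuAD's SQuAD-v1.1 layout (data → paragraphs → qas) into one flat record per question, which is the shape run_qa.py's JSON loader expects. Applying the same flattening to a hypothetical miniature record:

```python
import json

# Hypothetical miniature JSQuAD-style input, same nesting as the real files
j = {"data": [{"title": "t", "paragraphs": [{"context": "c",
      "qas": [{"id": "q1", "question": "?",
               "answers": [{"text": "a", "answer_start": 0}]}]}]}]}

# Same flattening as the loop above: one flat record per question
u = [{"id": q["id"], "title": d["title"], "context": p["context"],
      "question": q["question"],
      "answers": {"text": [x["text"] for x in q["answers"]],
                  "answer_start": [x["answer_start"] for x in q["answers"]]}}
     for d in j["data"] for p in d["paragraphs"] for q in p["qas"]]

print(u[0]["question"], u[0]["answers"])  # ? {'text': ['a'], 'answer_start': [0]}
```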
f="transformers-4.26.0/examples/pytorch/question-answering/run_qa.py"
!if fgrep trust_remote_code {f} ; then : ; else ( echo '%s/use_fast=.*,/& trust_remote_code=True,/' ; echo wq ) | ex -s {f} ; fi
!python {f} --model_name_or_path deberta-v2-base-japanese --do_train --do_eval --max_seq_length 384 --learning_rate 5e-05 --num_train_epochs 3 --per_device_train_batch_size 16 --per_device_eval_batch_size 16 --output_dir ./deberta-v2-base-japanese-jsquad --overwrite_output_dir --train_file train-v1.1.json --validation_file valid-v1.1.json --save_steps 5000 --warmup_ratio 0.1

Even on the GPU runtime it took five hours, but on my (Koichi Yasuoka's) end the following "eval metrics" were produced.

***** eval metrics *****
  epoch                   =        3.0
  eval_exact_match        =      90.86
  eval_f1                 =    90.9027
  eval_runtime            = 0:00:55.75
  eval_samples            =       4450
  eval_samples_per_second =      79.82
  eval_steps_per_second   =      5.004

EM/F1 = 0.9086/0.9090 on JSQuAD is, I think, a very good showing for a base-size model. I would also like to try ku-nlp/deberta-v2-large-japanese, but Google Colaboratory would probably cut that run off.
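For reference, eval_exact_match and eval_f1 follow the usual SQuAD definitions: EM is the share of predictions that match a gold answer exactly, and F1 measures bag-of-token overlap between prediction and gold (for Japanese this is commonly computed at the character level; I am glossing over JGLUE's exact normalization here). A minimal sketch of the two metrics:

```python
from collections import Counter

def exact_match(pred, gold):
    # 1.0 if the prediction equals the gold answer verbatim, else 0.0
    return float(pred == gold)

def f1(pred, gold):
    # Character-level F1: harmonic mean of precision and recall over
    # the multiset of characters shared by prediction and gold answer
    common = Counter(pred) & Counter(gold)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

print(exact_match("富士山", "富士山"))      # 1.0
print(round(f1("富士山頂", "富士山"), 4))   # 0.8571
```

The reported scores are these per-question values averaged over the 4,450 validation questions, scaled to percentages.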

This discussion was created by yasuoka (21275), visible to logged-in users only; new comments can no longer be posted.
