パスワードを忘れた? アカウント作成
15749766 journal
中国

yasuokaの日記: Bellman-FordはEvaHan 2022 Bakeoffの夢を見るか 2

日記 by yasuoka

一昨昨日の日記を眺めていて、「B-」「I-」ラベルにおける矛盾解消がちょっと雑なのが気になった。そこで、Bellman-Fordをlogits向けに変形して、挟み込んでみることにした。Google Colaboratory (GPU)だと、こんな感じ。

pretrained_model="SIKU-BERT/sikuroberta"
!pip install transformers
import os
url="https://github.com/CIRCSE/LT4HALA"
d=os.path.basename(url)
!test -d $d || git clone --depth=1 $url
!cp $d/2022/data_and_doc/EvaHan*.txt $d/2022/data_and_doc/*EvaHan*.py .
!sed '1s/^.//' EvaHan_testa_raw.txt | tr -d '\015' > testa.txt
!sed '1s/^.//' EvaHan_testb_raw.txt | tr -d '\015' > testb.txt
!test -f zuozhuan_train_utf8.txt || unzip $d/2022/data_and_doc/zuozhuan_train_utf8.zip
!sed '1s/^.//' zuozhuan_train_utf8.txt | tr -d '\015' | nawk '{{gsub(/。\/w/,"。/w\n");print}}' | egrep -v '^ *$' > train.txt
class EvaHanDataset(object):
  def __init__(self,file,tokenizer):
    self.ids,self.pos=[],[]
    label,cls,sep=set(),tokenizer.cls_token_id,tokenizer.sep_token_id
    with open(file,"r",encoding="utf-8") as r:
      for t in r:
        w,p=[k.split("/") for k in t.split()],[]
        v=tokenizer([k[0] for k in w],add_special_tokens=False)["input_ids"]
        for x,y in zip(v,w):
          if len(y)==1:
            y.append("w")
          if len(x)==1:
            p.append(y[1])
          elif len(x)>1:
            p.extend(["B-"+y[1]]+["I-"+y[1]]*(len(x)-1))
        self.ids.append([cls]+sum(v,[])+[sep])
        self.pos.append(["w"]+p+["w"])
        label=set(sum([self.pos[-1],list(label)],[]))
    self.label2id={l:i for i,l in enumerate(sorted(label))}
  __len__=lambda self:len(self.ids)
  __getitem__=lambda self,i:{"input_ids":self.ids[i],"labels":[self.label2id[t] for t in self.pos[i]]}
import torch,numpy
from transformers import AutoTokenizer,AutoConfig,AutoModel,AutoModelForTokenClassification,DataCollatorForTokenClassification,TrainingArguments,Trainer
tkz=AutoTokenizer.from_pretrained(pretrained_model)
mdl=AutoModel.from_pretrained(pretrained_model)
dir="."+pretrained_model.replace("/",".")
mdl.save_pretrained(dir+"/tmp-model")
trainDS=EvaHanDataset("train.txt",tkz)
cfg=AutoConfig.from_pretrained(dir+"/tmp-model",num_labels=len(trainDS.label2id),label2id=trainDS.label2id,id2label={i:l for l,i in trainDS.label2id.items()})
arg=TrainingArguments(per_device_train_batch_size=32,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2,save_strategy="epoch")
trn=Trainer(model=AutoModelForTokenClassification.from_pretrained(dir+"/tmp-model",config=cfg),args=arg,train_dataset=trainDS,data_collator=DataCollatorForTokenClassification(tkz))
trn.train()
trn.save_model(dir+"/evahan2022-model")
tkz.save_pretrained(dir+"/evahan2022-model")
mdl=AutoModelForTokenClassification.from_pretrained(dir+"/evahan2022-model")
idl=mdl.config.id2label
mtx=numpy.full((len(idl),len(idl)),numpy.nan)
d=numpy.array([numpy.nan if idl[i].startswith("I-") else 0 for i in range(len(idl))])
for i in range(len(idl)):
  if idl[i].startswith("B-"):
    mtx[i,mdl.config.label2id["I-"+idl[i][2:]]]=0
  else:
    mtx[i]=d
    if idl[i].startswith("I-"):
      mtx[i,i]=0
for f in ["testa","testb"]:
  with open(f+".txt","r",encoding="utf-8") as r:
    u,e=[],[]
    for s in r:
      t=s.strip().split("。")
      w=[j+"。" if i<len(t)-1 else j for i,j in enumerate(t) if j!=""]
      if len(w)==0:
        e[-1]=e[-1]+"\n"
      else:
        u.extend(w)
        e.extend([" "]*(len(w)-1)+["\n"])
  with open(f+dir+".txt","w",encoding="utf-8") as w:
    with torch.no_grad():
      for s,z in zip(u,e):
        v=tkz(s,return_offsets_mapping=True)
        m=mdl(torch.tensor([v["input_ids"]])).logits[0].numpy()
        for i in range(m.shape[0]-1,0,-1):
          m[i-1]+=numpy.nanmax(m[i]+mtx,axis=1)
        p=[numpy.nanargmax(m[0])]
        for i in range(1,m.shape[0])):
          p.append(numpy.nanargmax(m[i]+mtx[p[-1]]))
        d=[[idl[q],s[t[0]:t[1]]] for q,t in zip(p,v["offset_mapping"]) if t[0]<t[1]]
        for i in range(len(d)-1,0,-1):
          if d[i][0].startswith("I-"):
            if d[i-1][0].startswith("B-"):
              e=d.pop(i)
              d[i-1]=[d[i-1][0][2:],d[i-1][1]+e[1]]
            elif d[i-1][0].startswith("I-"):
              e=d.pop(i)
              d[i-1][1]=d[i-1][1]+e[1]
        for i in range(len(d)):
          if d[i][0].startswith("B-") or d[i][0].startswith("I-"):
            d[i][0]=d[i][0][2:]
        print(" ".join(t[1]+"/"+t[0] for t in d),file=w,end=z)
!python eval_EvaHan_2022_FINAL.py testa{dir}.txt EvaHan_testa_gold.txt
!python eval_EvaHan_2022_FINAL.py testb{dir}.txt EvaHan_testb_gold.txt

私(安岡孝一)の手元では、15分ほどで以下の結果が得られた。

The result of testa.SIKU-BERT.sikuroberta.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 95.7918 | 96.6976 | 96.2426 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 91.5660 | 92.4318 | 91.9969 |
+-----------------+---------+---------+---------+

The result of testb.SIKU-BERT.sikuroberta.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 94.4153 | 91.3848 | 92.8753 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 88.6004 | 85.7565 | 87.1552 |
+-----------------+---------+---------+---------+

一昨昨日に較べると、「Pos tagging」における解析精度が上がっている。ちなみに1行目を「pretrained_model="KoichiYasuoka/bert-ancient-chinese-base-upos"」に変えたところ、私の手元では以下の結果になった。

The result of testa.KoichiYasuoka.bert-ancient-chinese-base-upos.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 95.9862 | 96.7403 | 96.3617 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 91.9335 | 92.6558 | 92.2933 |
+-----------------+---------+---------+---------+

The result of testb.KoichiYasuoka.bert-ancient-chinese-base-upos.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 94.8630 | 91.8269 | 93.3202 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 88.9929 | 86.1447 | 87.5457 |
+-----------------+---------+---------+---------+

あるいは1行目を「pretrained_model="Jihuai/bert-ancient-chinese"」に変えたところ、私の手元では以下の結果になった。

The result of testa.Jihuai.bert-ancient-chinese.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 96.1192 | 96.8504 | 96.4835 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 92.0233 | 92.7233 | 92.3720 |
+-----------------+---------+---------+---------+

The result of testb.Jihuai.bert-ancient-chinese.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 94.6381 | 91.4387 | 93.0109 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 88.5360 | 85.5429 | 87.0137 |
+-----------------+---------+---------+---------+

ざっと見た限りだと、Bellman-Fordが「Pos tagging」を改善するのは間違いなさそうだ。ただし、それでも、復旦大学のチームに今一歩で及んでいない。やっぱり世界の壁は厚いなあ。

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
  • by murawaki (48618) on 2022年08月06日 9時15分 (#4302842) ホームページ
    言語は系列データなので、一般のグラフアルゴリズムよりも、系列データであることを利用した手法を使うことが多いです。 タグ付けで標準的なのは (linear chain) CRFです。 https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html [pytorch.org] https://pytorch-crf.readthedocs.io/en/stable/ [readthedocs.io] トークンごとの分類器が吐く logit を CRF における emission スコアとみなしたうえで、transition スコアを追加していることになります。 BIタギングなら、(品詞の異なり数 x 2) ^ 2 大きさの transition table を訓練データから学習します。 B-a + I-b のような不正な系列は訓練データに出てこないので、普通に学習するだけテスト時に出てこなくなります。 これはこれでオーバーキル感がありますが、確立された手法なので使うことがあるといったところです。
    • by yasuoka (21275) on 2022年08月07日 13時17分 (#4303259) 日記

      ありがとうございます。今回、復旦大学グループがCRF (Conditional Random Fields)を使っているので、そっちの方が精度あがるんだろうな、と思いつつも、アマノジャクなので別のアルゴリズムを試したくなってしまうのです。でも結局、今のところはCRFなのかなぁ…。

      親コメント
typodupeerror

ハッカーとクラッカーの違い。大してないと思います -- あるアレゲ

読み込み中...