パスワードを忘れた? アカウント作成
15746834 journal
中国

yasuokaの日記: TransformersのTokenClassificationでEvaHan 2022 Bakeoff

日記 by yasuoka

遅ればせながら昨日の日記で挑戦したEvaHan 2022 Bakeoffだが、esuparの学習モジュールからUniversal Dependenciesガラミのアレコレを外して、TransformersのTrainerを使う形で書き直してみた。Google Colaboratory (GPU)だと、こんな感じ。

!pip install transformers
import os
url="https://github.com/CIRCSE/LT4HALA"
d=os.path.basename(url)
!test -d $d || git clone --depth=1 $url
!cp $d/2022/data_and_doc/EvaHan*.txt $d/2022/data_and_doc/*EvaHan*.py .
!sed '1s/^.//' EvaHan_testa_raw.txt | tr -d '\015' > testa.txt
!sed '1s/^.//' EvaHan_testb_raw.txt | tr -d '\015' > testb.txt
!test -f zuozhuan_train_utf8.txt || unzip $d/2022/data_and_doc/zuozhuan_train_utf8.zip
!sed '1s/^.//' zuozhuan_train_utf8.txt | tr -d '\015' | nawk '{{gsub(/。\/w/,"。/w\n");print}}' | egrep -v '^ *$' > train.txt
class EvaHanDataset(object):
  def __init__(self,file,tokenizer):
    self.ids,self.pos=[],[]
    label,cls,sep=set(),tokenizer.cls_token_id,tokenizer.sep_token_id
    with open(file,"r",encoding="utf-8") as r:
      for t in r:
        w,p=[k.split("/") for k in t.split()],[]
        v=tokenizer([k[0] for k in w],add_special_tokens=False)["input_ids"]
        for x,y in zip(v,w):
          if len(y)==1:
            y.append("w")
          if len(x)==1:
            p.append(y[1])
          elif len(x)>1:
            p.extend(["B-"+y[1]]+["I-"+y[1]]*(len(x)-1))
        self.ids.append([cls]+sum(v,[])+[sep])
        self.pos.append(["w"]+p+["w"])
        label=set(sum([self.pos[-1],list(label)],[]))
    self.label2id={l:i for i,l in enumerate(sorted(label))}
  __len__=lambda self:len(self.ids)
  __getitem__=lambda self,i:{"input_ids":self.ids[i],"labels":[self.label2id[t] for t in self.pos[i]]}
from transformers import AutoTokenizer,AutoConfig,AutoModelForTokenClassification,DataCollatorForTokenClassification,TrainingArguments,Trainer,pipeline
brt="SIKU-BERT/sikuroberta"
tkz=AutoTokenizer.from_pretrained(brt)
trainDS=EvaHanDataset("train.txt",tkz)
cfg=AutoConfig.from_pretrained(brt,num_labels=len(trainDS.label2id),label2id=trainDS.label2id,id2label={i:l for l,i in trainDS.label2id.items()})
arg=TrainingArguments(per_device_train_batch_size=32,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2,save_strategy="epoch")
trn=Trainer(model=AutoModelForTokenClassification.from_pretrained(brt,config=cfg),args=arg,train_dataset=trainDS,data_collator=DataCollatorForTokenClassification(tkz))
trn.train()
trn.save_model("roberta-han")
tkz.save_pretrained("roberta-han")
tagger=pipeline(task="ner",model="roberta-han",device=0)
for f in ["testa","testb"]:
  with open(f+".txt","r",encoding="utf-8") as r:
    with open(f+"_close.txt","w",encoding="utf-8") as w:
      for s in r:
        d=[]
        if s.strip()!="":
          t=s.split("。")
          u=[j+"。" if i<len(t)-1 else j for i,j in enumerate(t) if j!=""]
          v=tagger(u)
          for j,k in zip(u,v):
            d+=[[t["entity"],j[t["start"]:t["end"]]] for t in k]
        for i in range(len(d)-1,0,-1):
          if d[i][0].startswith("I-"):
            if d[i-1][0].startswith("B-"):
              e=d.pop(i)
              d[i-1]=[d[i-1][0][2:],d[i-1][1]+e[1]]
            elif d[i-1][0].startswith("I-"):
              e=d.pop(i)
              d[i-1][1]=d[i-1][1]+e[1]
        for i in range(len(d)):
          if d[i][0].startswith("B-") or d[i][0].startswith("I-"):
            d[i][0]=d[i][0][2:]
        print(" ".join(t[1]+"/"+t[0] for t in d),file=w)
!python eval_EvaHan_2022_FINAL.py testa_close.txt EvaHan_testa_gold.txt
!python eval_EvaHan_2022_FINAL.py testb_close.txt EvaHan_testb_gold.txt

ただし、EvaHan向けに品詞「w」を特別扱いしている。私(安岡孝一)の手元では、10分ほどで以下の結果が得られた。

The result of testa_close.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 95.3232 | 96.7260 | 96.0195 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 90.9441 | 92.2825 | 91.6084 |
+-----------------+---------+---------+---------+

The result of testb_close.txt is:
+-----------------+---------+---------+---------+
|       Task      |    P    |    R    |    F1   |
+-----------------+---------+---------+---------+
|Word segmentation| 94.6465 | 92.1817 | 93.3978 |
+-----------------+---------+---------+---------+
|   Pos tagging   | 88.2383 | 85.9404 | 87.0742 |
+-----------------+---------+---------+---------+

昨日の結果に較べて、TestAは少し悪くなっているものの、TestBは少し良くなっている。タマタマそういう結果なのだが、これなら「closed modality」だと信じてもらえるだろうか。

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
typodupeerror

UNIXはシンプルである。必要なのはそのシンプルさを理解する素質だけである -- Dennis Ritchie

読み込み中...