パスワードを忘れた? アカウント作成
15826960 journal
人工知能

yasuokaの日記: roberta-base-english-ud-goeswithで見る英文係り受け隣接行列ロジット

日記 by yasuoka

10月15日の日記に続いて、英語係り受け解析モデルroberta-base-english-ud-goeswith作ってみた。Google Colaboratoryで動かしてみよう。

!pip install transformers ufal.chu_liu_edmonds deplacy
class UDgoeswith(object):
  def __init__(self,bert):
    from transformers import AutoTokenizer,AutoModelForTokenClassification
    self.tokenizer=AutoTokenizer.from_pretrained(bert)
    self.model=AutoModelForTokenClassification.from_pretrained(bert)
  def __call__(self,text):
    import numpy,torch,ufal.chu_liu_edmonds
    w=self.tokenizer(text,return_offsets_mapping=True)
    v=w["input_ids"]
    x=[v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor(x)).logits.numpy()[:,1:-2,:]
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
    m=numpy.full((e.shape[0]+1,e.shape[1]+1),numpy.nan)
    m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
    p=numpy.zeros(m.shape)
    p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
    for i in range(1,m.shape[0]):
      m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
    h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
    if [0 for i in h if i==0]!=[0]:
      m[:,0]+=numpy.where(m[:,0]<numpy.nanmax(m[:,0]),numpy.nan,0)
      h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
    u="# text = "+text+"\n"
    v=[(s,e) for s,e in w["offset_mapping"] if s<e]
    for i,(s,e) in enumerate(v,1):
      q=self.model.config.id2label[p[i,h[i]]].split("|")
      u+="\t".join([str(i),text[s:e],"_",q[0],"_","|".join(q[1:-1]),str(h[i]),q[-1],"_","_" if i<len(v) and e<v[i][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
nlp=UDgoeswith("KoichiYasuoka/roberta-base-english-ud-goeswith")
doc=nlp("In Sendai who did you wait for?")
import deplacy
deplacy.render(doc)
deplacy.serve(doc,port=None)

「In Sendai who did you wait for?」を係り受け解析したところ、私(安岡孝一)の手元では以下の結果になった。

In   ADP   <══╗         case
Send PROPN ═╗═╝<╗       obl
ai   X     <╝   ║       goeswith
who  PRON  ═════║═══╗<╗ obl
did  AUX   <══╗ ║   ║ ║ aux
you  PRON  <╗ ║ ║   ║ ║ nsubj
wait VERB  ═╝═╝═╝═╗═══╝ root
for  ADP   <══════║═╝   case
?    PUNCT <══════╝     punct

# text = In Sendai who did you wait for?
1    In    _    ADP    _    _    2    case    _    _
2    Send    _    PROPN    _    Number=Sing    7    obl    _    SpaceAfter=No
3    ai    _    X    _    _    2    goeswith    _    _
4    who    _    PRON    _    PronType=Int    7    obl    _    _
5    did    _    AUX    _    Mood=Ind|Tense=Past|VerbForm=Fin    7    aux    _    _
6    you    _    PRON    _    _    7    nsubj    _    _
7    wait    _    VERB    _    VerbForm=Inf    0    root    _    _
8    for    _    ADP    _    _    4    case    _    SpaceAfter=No
9    ?    _    PUNCT    _    _    7    punct    _    SpaceAfter=No

SVGで可視化すると、こんな感じ。本来1語であるべき「Send」「ai」がgoeswith(泣き別れ)で繋がれていて、しかも「for」の係り受けがうまく交差している。もちろん、このモデルにおいても、内部の係り受け隣接行列を剝き身で見ることができる。ちょっと見てみよう。

!pip install transformers
import torch,numpy
from transformers import AutoTokenizer,AutoModelForTokenClassification
txt="In Sendai who did you wait for?"
brt="KoichiYasuoka/roberta-base-english-ud-goeswith"
tkz=AutoTokenizer.from_pretrained(brt)
mdl=AutoModelForTokenClassification.from_pretrained(brt)
v,l=tkz(txt,return_offsets_mapping=True),mdl.config.id2label
w,u=v["input_ids"],[txt[s:e] for s,e in v["offset_mapping"] if s<e]
x=[w[:i]+[tkz.mask_token_id]+w[i+1:]+[j] for i,j in enumerate(w[1:-1],1)]
with torch.no_grad():
  m=mdl(input_ids=torch.tensor(x)).logits.numpy()[:,1:-2,:]
r=[1 if i==0 else -1 if l[i].endswith("|root") else 0 for i in range(len(l))]
m+=numpy.where(numpy.add.outer(numpy.identity(m.shape[0]),r)==0,0,numpy.nan)
d,p=numpy.nanmax(m,axis=2),numpy.nanargmax(m,axis=2)
print(" ".join(x[:9].rjust(9) for x in u))
for i,j in enumerate(u):
  print("\n"+" ".join("{:9.3f}".format(x) for x in d[i])," ",j)
  print(" ".join(l[x].split("|")[-1][:9].rjust(9) for x in p[i]))

隣接行列のロジット(対数オッズ)は、以下の結果になった。

       In      Send        ai       who       did       you      wait       for         ?

    1.327     4.288     2.600     2.184     1.790     2.321     4.332     9.080     7.557   In
     root      nmod  goeswith       obj     punct     punct acl:relcl      case     punct

   16.037     5.150    18.021     6.877     2.457     3.484     5.519    10.393     8.274   Send
     case      root  goeswith       obj       aux     punct acl:relcl      case     punct

    3.192     2.593     1.640     2.150     1.640     1.841     1.370     6.699     6.378   ai
     case  compound      root  goeswith      nmod  goeswith parataxis      case  goeswith

    8.258     8.220     2.147     2.297     2.162     3.264     4.715    11.902     8.481   who
     case      nmod      case      root       aux     nsubj acl:relcl      case     punct

    4.999     5.457     2.181     4.621     1.293     2.893     2.874     9.328     7.291   did
     case      nmod      case       obj      root     punct acl:relcl      case     punct

    4.649     7.236     2.163     3.127     1.573     1.225     4.021     9.544     6.739   you
     case      nmod      nmod       obj      nmod      root       acl      case     punct

    4.544    13.536     3.379    11.626    12.723    12.845    15.045     8.809    15.645   wait
     case       obl     punct       obl       aux     nsubj      root      case     punct

    3.974     4.175     3.073     4.627     1.688     2.486     2.105     1.952     7.491   for
goeswith       obl  goeswith       obj       aux     punct      conj      root  goeswith

    2.489     2.391     2.534     1.580     1.786     1.961     1.388     9.512     2.031   ?
goeswith      nmod  goeswith  goeswith      nmod  goeswith  goeswith      case      root

「for」の列のロジットがかなり拮抗しており、僅差で「who」⇒「for」を選んでいるのがわかる。ただ、このモデルでは縮約語がうまく扱えないので、ちょっと悩ましい。さて、そのあたり、どうすればいいかな。

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
typodupeerror

アレゲはアレゲを呼ぶ -- ある傍観者

読み込み中...