パスワードを忘れた? アカウント作成
15823783 journal
人工知能

yasuokaの日記: roberta-base-japanese-aozora-ud-goeswithで見る「しゃぶしゃぶを繁華街に食べに行く」の係り受け隣接行列ロジット

日記 by yasuoka

思うところあって、国語研長単位係り受け解析モデルroberta-base-japanese-aozora-ud-goeswith作ってみた。Google Colaboratoryで動かしてみよう。

!pip install transformers ufal.chu_liu_edmonds deplacy
class UDgoeswith(object):
  def __init__(self,bert):
    from transformers import AutoTokenizer,AutoModelForTokenClassification
    self.tokenizer=AutoTokenizer.from_pretrained(bert)
    self.model=AutoModelForTokenClassification.from_pretrained(bert)
  def __call__(self,text):
    import numpy,torch,ufal.chu_liu_edmonds
    w=self.tokenizer(text,return_offsets_mapping=True)
    v=w["input_ids"]
    n=len(v)-1
    with torch.no_grad():
      d=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[v[i]] for i in range(1,n)]))
    e=d.logits.numpy()[:,1:n,:]
    e[:,:,0]=numpy.nan
    m=numpy.full((n,n),numpy.nan)
    m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
    p=numpy.zeros((n,n))
    p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
    for i in range(1,n):
      m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
    h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
    u="# text = "+text+"\n"
    v=[(s,e) for s,e in w["offset_mapping"] if s<e]
    for i,(s,e) in enumerate(v,1):
      q=self.model.config.id2label[p[i,h[i]]].split("|")
      u+="\t".join([str(i),text[s:e],"_",q[0],"_","|".join(q[1:-1]),str(h[i]),q[-1],"_","_" if i<len(v) and e<v[i][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
nlp=UDgoeswith("KoichiYasuoka/roberta-base-japanese-aozora-ud-goeswith")
doc=nlp("しゃぶしゃぶを繁華街に食べに行く")
import deplacy
deplacy.render(doc,Japanese=True)
deplacy.serve(doc,port=None)

「しゃぶしゃぶを繁華街に食べに行く」を係り受け解析したところ、私(安岡孝一)の手元では以下の結果になった。

しゃぶ NOUN ═╗═╗<╗     obj(目的語)
しゃぶ X    <╝ ║ ║     goeswith(泣き別れ)
を     ADP  <══╝ ║     case(格表示)
繁華   NOUN ═╗═╗ ║<══╗ obl(斜格補語)
街     X    <╝ ║ ║   ║ goeswith(泣き別れ)
に     ADP  <══╝ ║   ║ case(格表示)
食べ   VERB ═╗═══╝<╗ ║ advcl(連用修飾節)
に     ADP  <╝     ║ ║ case(格表示)
行く   VERB ═══════╝═╝ root(親)

# text = しゃぶしゃぶを繁華街に食べに行く
1    しゃぶ    _    NOUN    _    _    7    obj    _    SpaceAfter=No
2    しゃぶ    _    X    _    _    1    goeswith    _    SpaceAfter=No
3    を    _    ADP    _    _    1    case    _    SpaceAfter=No
4    繁華    _    NOUN    _    _    9    obl    _    SpaceAfter=No
5    街    _    X    _    _    4    goeswith    _    SpaceAfter=No
6    に    _    ADP    _    _    4    case    _    SpaceAfter=No
7    食べ    _    VERB    _    _    9    advcl    _    SpaceAfter=No
8    に    _    ADP    _    _    7    case    _    SpaceAfter=No
9    行く    _    VERB    _    _    0    root    _    SpaceAfter=No

SVGで可視化すると、こんな感じ。本来1語であるべきサブワードの「しゃぶ」と「しゃぶ」がgoeswith(泣き別れ)で繋がれていて、しかも係り受けがうまく交差している。これに加えて、このモデルの売りは、内部の係り受け隣接行列を、剝き身で見ることができるのだ。ちょっと見てみよう。

!pip install transformers
import torch,numpy
from transformers import AutoTokenizer,AutoModelForTokenClassification
txt="しゃぶしゃぶを繁華街に食べに行く"
brt="KoichiYasuoka/roberta-base-japanese-aozora-ud-goeswith"
tkz=AutoTokenizer.from_pretrained(brt)
mdl=AutoModelForTokenClassification.from_pretrained(brt)
v,l=tkz(txt,return_offsets_mapping=True),mdl.config.id2label
w,u=v["input_ids"],[txt[s:e] for s,e in v["offset_mapping"] if s<e]
x=[w[:i]+[tkz.mask_token_id]+w[i+1:]+[w[i]] for i in range(1,len(w)-1)]
with torch.no_grad():
  m=mdl(input_ids=torch.tensor(x)).logits.numpy()[:,1:len(w)-1,1:]
d,p=numpy.max(m,axis=2),numpy.argmax(m,axis=2)+1
print(" ".join(x.rjust(9-len(x)) for x in u))
for i,j in enumerate(u):
  print("\n"+" ".join("{:9.3f}".format(x) for x in d[i])," ",j)
  print(" ".join(l[x].split("|")[-1][:9].rjust(9) for x in p[i]))

隣接行列のロジット(対数オッズ)は、以下の結果になった。

   しゃぶ    しゃぶ        を      繁華        街        に      食べ        に      行く

   -0.083    11.656    12.093    -0.699    -0.850    -0.776    -1.139    -1.010    -0.271   しゃぶ
      obj  goeswith      case     punct     punct      case     punct     punct     punct

   -1.168    -1.449     1.025    -0.987    -1.279    -0.998    -1.074    -1.251    -0.969   しゃぶ
    punct  goeswith      case       obl  goeswith      case     advcl      case     punct

   -0.369    -1.019    -1.440    -0.695    -0.620    -0.681    -1.154    -1.191    -1.156   を
      obl  goeswith  goeswith       obl  goeswith      case     advcl      case     advcl

    2.753    -0.352    -0.924     0.065    11.679    12.137    -0.319     0.465    -0.547   繁華
      obj  goeswith      case       obl  goeswith      case  goeswith      case     punct

   -0.313    -1.049    -0.876    -0.669    -1.529     1.399    -0.833    -0.698    -0.816   街
      obj  goeswith      case       obl  goeswith      case  goeswith      case     punct

    0.009    -0.895    -1.091    -0.012    -1.084    -1.176     0.199     0.141    -0.859   に
      obj  goeswith      case       obl  goeswith      case  goeswith      case     punct

    7.871    -0.801    -0.697     7.849    -0.872    -0.539     1.091    11.413     1.810   食べ
      obj  goeswith      case       obl       obl      case      root      case     advcl

   -0.351    -0.810    -1.020     0.232    -0.974    -1.161     0.529    -1.258     0.637   に
      obj  goeswith      case       obl       obl      case     advcl  goeswith  goeswith

    6.927    -0.486    -0.927     7.938     0.004    -0.990     8.629    -0.630     6.523   行く
      obj  goeswith      case       obl       obl      case     advcl     punct      root

「しゃぶ」の列も「繁華」の列も、「食べ」の行と「行く」の行のロジットがかなり拮抗していて、ギリギリの線で「しゃぶ」⇐「食べ」と「繁華」⇐「行く」になっているのがわかる。さて、これで、チューニングがもう少し楽になるかな。

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
typodupeerror

あと、僕は馬鹿なことをするのは嫌いですよ (わざとやるとき以外は)。-- Larry Wall

読み込み中...