yasuokaの日記: roberta-base-japanese-aozora-ud-goeswithで見る「しゃぶしゃぶを繁華街に食べに行く」の係り受け隣接行列ロジット
思うところあって、国語研長単位係り受け解析モデルroberta-base-japanese-aozora-ud-goeswithを作ってみた。Google Colaboratoryで動かしてみよう。
!pip install transformers ufal.chu_liu_edmonds deplacy
class UDgoeswith(object):
def __init__(self,bert):
from transformers import AutoTokenizer,AutoModelForTokenClassification
self.tokenizer=AutoTokenizer.from_pretrained(bert)
self.model=AutoModelForTokenClassification.from_pretrained(bert)
def __call__(self,text):
import numpy,torch,ufal.chu_liu_edmonds
w=self.tokenizer(text,return_offsets_mapping=True)
v=w["input_ids"]
n=len(v)-1
with torch.no_grad():
d=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[v[i]] for i in range(1,n)]))
e=d.logits.numpy()[:,1:n,:]
e[:,:,0]=numpy.nan
m=numpy.full((n,n),numpy.nan)
m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
p=numpy.zeros((n,n))
p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
for i in range(1,n):
m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
u="# text = "+text+"\n"
v=[(s,e) for s,e in w["offset_mapping"] if s<e]
for i,(s,e) in enumerate(v,1):
q=self.model.config.id2label[p[i,h[i]]].split("|")
u+="\t".join([str(i),text[s:e],"_",q[0],"_","|".join(q[1:-1]),str(h[i]),q[-1],"_","_" if i<len(v) and e<v[i][0] else "SpaceAfter=No"])+"\n"
return u+"\n"
nlp=UDgoeswith("KoichiYasuoka/roberta-base-japanese-aozora-ud-goeswith")
doc=nlp("しゃぶしゃぶを繁華街に食べに行く")
import deplacy
deplacy.render(doc,Japanese=True)
deplacy.serve(doc,port=None)
「しゃぶしゃぶを繁華街に食べに行く」を係り受け解析したところ、私(安岡孝一)の手元では以下の結果になった。
しゃぶ NOUN ═╗═╗<╗ obj(目的語)
しゃぶ X <╝ ║ ║ goeswith(泣き別れ)
を ADP <══╝ ║ case(格表示)
繁華 NOUN ═╗═╗ ║<══╗ obl(斜格補語)
街 X <╝ ║ ║ ║ goeswith(泣き別れ)
に ADP <══╝ ║ ║ case(格表示)
食べ VERB ═╗═══╝<╗ ║ advcl(連用修飾節)
に ADP <╝ ║ ║ case(格表示)
行く VERB ═══════╝═╝ root(親)
# text = しゃぶしゃぶを繁華街に食べに行く
1 しゃぶ _ NOUN _ _ 7 obj _ SpaceAfter=No
2 しゃぶ _ X _ _ 1 goeswith _ SpaceAfter=No
3 を _ ADP _ _ 1 case _ SpaceAfter=No
4 繁華 _ NOUN _ _ 9 obl _ SpaceAfter=No
5 街 _ X _ _ 4 goeswith _ SpaceAfter=No
6 に _ ADP _ _ 4 case _ SpaceAfter=No
7 食べ _ VERB _ _ 9 advcl _ SpaceAfter=No
8 に _ ADP _ _ 7 case _ SpaceAfter=No
9 行く _ VERB _ _ 0 root _ SpaceAfter=No
SVGで可視化すると、こんな感じ。本来1語であるべきサブワードの「しゃぶ」と「しゃぶ」がgoeswith(泣き別れ)で繋がれていて、しかも係り受けがうまく交差している。これに加えて、このモデルの売りは、内部の係り受け隣接行列を、剝き身で見ることができるのだ。ちょっと見てみよう。
!pip install transformers
import torch,numpy
from transformers import AutoTokenizer,AutoModelForTokenClassification
txt="しゃぶしゃぶを繁華街に食べに行く"
brt="KoichiYasuoka/roberta-base-japanese-aozora-ud-goeswith"
tkz=AutoTokenizer.from_pretrained(brt)
mdl=AutoModelForTokenClassification.from_pretrained(brt)
v,l=tkz(txt,return_offsets_mapping=True),mdl.config.id2label
w,u=v["input_ids"],[txt[s:e] for s,e in v["offset_mapping"] if s<e]
x=[w[:i]+[tkz.mask_token_id]+w[i+1:]+[w[i]] for i in range(1,len(w)-1)]
with torch.no_grad():
m=mdl(input_ids=torch.tensor(x)).logits.numpy()[:,1:len(w)-1,1:]
d,p=numpy.max(m,axis=2),numpy.argmax(m,axis=2)+1
print(" ".join(x.rjust(9-len(x)) for x in u))
for i,j in enumerate(u):
print("\n"+" ".join("{:9.3f}".format(x) for x in d[i])," ",j)
print(" ".join(l[x].split("|")[-1][:9].rjust(9) for x in p[i]))
隣接行列のロジット(対数オッズ)は、以下の結果になった。
しゃぶ しゃぶ を 繁華 街 に 食べ に 行く
-0.083 11.656 12.093 -0.699 -0.850 -0.776 -1.139 -1.010 -0.271 しゃぶ
obj goeswith case punct punct case punct punct punct
-1.168 -1.449 1.025 -0.987 -1.279 -0.998 -1.074 -1.251 -0.969 しゃぶ
punct goeswith case obl goeswith case advcl case punct
-0.369 -1.019 -1.440 -0.695 -0.620 -0.681 -1.154 -1.191 -1.156 を
obl goeswith goeswith obl goeswith case advcl case advcl
2.753 -0.352 -0.924 0.065 11.679 12.137 -0.319 0.465 -0.547 繁華
obj goeswith case obl goeswith case goeswith case punct
-0.313 -1.049 -0.876 -0.669 -1.529 1.399 -0.833 -0.698 -0.816 街
obj goeswith case obl goeswith case goeswith case punct
0.009 -0.895 -1.091 -0.012 -1.084 -1.176 0.199 0.141 -0.859 に
obj goeswith case obl goeswith case goeswith case punct
7.871 -0.801 -0.697 7.849 -0.872 -0.539 1.091 11.413 1.810 食べ
obj goeswith case obl obl case root case advcl
-0.351 -0.810 -1.020 0.232 -0.974 -1.161 0.529 -1.258 0.637 に
obj goeswith case obl obl case advcl goeswith goeswith
6.927 -0.486 -0.927 7.938 0.004 -0.990 8.629 -0.630 6.523 行く
obj goeswith case obl obl case advcl punct root
「しゃぶ」の列も「繁華」の列も、「食べ」の行と「行く」の行のロジットがかなり拮抗していて、ギリギリの線で「しゃぶ」⇐「食べ」と「繁華」⇐「行く」になっているのがわかる。さて、これで、チューニングがもう少し楽になるかな。
roberta-base-japanese-aozora-ud-goeswithで見る「しゃぶしゃぶを繁華街に食べに行く」の係り受け隣接行列ロジット More ログイン