yasuoka's diary: Viewing the English dependency adjacency-matrix logits of roberta-base-english-ud-goeswith
Following my October 15 diary entry, I have also built an English dependency-parsing model, roberta-base-english-ud-goeswith. Let's run it on Google Colaboratory.
!pip install transformers ufal.chu_liu_edmonds deplacy
class UDgoeswith(object):
  def __init__(self,bert):
    from transformers import AutoTokenizer,AutoModelForTokenClassification
    self.tokenizer=AutoTokenizer.from_pretrained(bert)
    self.model=AutoModelForTokenClassification.from_pretrained(bert)
  def __call__(self,text):
    import numpy,torch,ufal.chu_liu_edmonds
    w=self.tokenizer(text,return_offsets_mapping=True)
    v=w["input_ids"]
    # mask each token in turn, appending the original token at the end
    x=[v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor(x)).logits.numpy()[:,1:-2,:]
    # allow "|root" labels only on the diagonal and forbid them elsewhere
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
    m=numpy.full((e.shape[0]+1,e.shape[1]+1),numpy.nan)
    m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
    p=numpy.zeros(m.shape)
    p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
    # move the diagonal (root) scores into column 0, the dummy root node
    for i in range(1,m.shape[0]):
      m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
    h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
    if [0 for i in h if i==0]!=[0]:
      # more than one root: keep only the strongest candidate and retry
      m[:,0]+=numpy.where(m[:,0]<numpy.nanmax(m[:,0]),numpy.nan,0)
      h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
    u="# text = "+text+"\n"
    v=[(s,e) for s,e in w["offset_mapping"] if s<e]
    for i,(s,e) in enumerate(v,1):
      q=self.model.config.id2label[p[i,h[i]]].split("|")
      u+="\t".join([str(i),text[s:e],"_",q[0],"_","|".join(q[1:-1]),str(h[i]),q[-1],"_","_" if i<len(v) and e<v[i][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
nlp=UDgoeswith("KoichiYasuoka/roberta-base-english-ud-goeswith")
doc=nlp("In Sendai who did you wait for?")
import deplacy
deplacy.render(doc)
deplacy.serve(doc,port=None)
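The nan-masking line `e+=numpy.where(numpy.add.outer(numpy.identity(...),r)==0,0,numpy.nan)` is compact enough to deserve a remark. Here is a minimal sketch with toy numbers of my own (two tokens, three labels) showing exactly what survives the mask:

```python
import numpy

# Toy numbers of my own, not the model's: label 0 is the special label,
# label 1 ends in "|root", label 2 is an ordinary relation, exactly as
# encoded by r=[1,-1,0].  numpy.add.outer(identity,r) is 0 only where a
# label is allowed, so the mask keeps "|root" labels on the diagonal,
# ordinary labels off the diagonal, and blanks out label 0 everywhere.
r=[1,-1,0]
e=numpy.zeros((2,2,3))
e+=numpy.where(numpy.add.outer(numpy.identity(2),r)==0,0,numpy.nan)
print(e[0,0])  # only the "|root" label survives on the diagonal
print(e[0,1])  # only the ordinary label survives off the diagonal
```

After this step, nanmax/nanargmax over the label axis automatically respect those constraints.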
When I parsed "In Sendai who did you wait for?", I (Koichi Yasuoka) got the following result on my machine.
In ADP <══╗ case
Send PROPN ═╗═╝<╗ obl
ai X <╝ ║ goeswith
who PRON ═════║═══╗<╗ obl
did AUX <══╗ ║ ║ ║ aux
you PRON <╗ ║ ║ ║ ║ nsubj
wait VERB ═╝═╝═╝═╗═══╝ root
for ADP <══════║═╝ case
? PUNCT <══════╝ punct
# text = In Sendai who did you wait for?
1 In _ ADP _ _ 2 case _ _
2 Send _ PROPN _ Number=Sing 7 obl _ SpaceAfter=No
3 ai _ X _ _ 2 goeswith _ _
4 who _ PRON _ PronType=Int 7 obl _ _
5 did _ AUX _ Mood=Ind|Tense=Past|VerbForm=Fin 7 aux _ _
6 you _ PRON _ _ 7 nsubj _ _
7 wait _ VERB _ VerbForm=Inf 0 root _ _
8 for _ ADP _ _ 4 case _ SpaceAfter=No
9 ? _ PUNCT _ _ 7 punct _ SpaceAfter=No
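One detail of the decoder above worth isolating is the single-root fallback: when chu_liu_edmonds attaches more than one token to node 0, the code nan-masks every root score except the strongest and decodes again. A sketch with toy numbers of my own:

```python
import numpy

# m0 plays the role of the root column m[:,0]; cell 0 is the dummy root
# node itself.  After the masking step only the strongest root candidate
# keeps a finite score, so the second decoding pass must pick it alone.
m0=numpy.array([numpy.nan,3.0,7.0,5.0])
m0+=numpy.where(m0<numpy.nanmax(m0),numpy.nan,0)
print(m0)  # token 2 (score 7.0) is the sole surviving root candidate
```

Adding 0 keeps the winner unchanged while adding nan wipes out every rival, which is why the one-liner needs no explicit loop.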
Visualized as SVG, it looks like the above. "Send" and "ai", which ought to be a single word, are linked by goeswith (a split token), and on top of that the dependency arc of "for" crosses the others properly. Of course, with this model too, we can look at the internal dependency adjacency matrix in the raw. Let's take a peek.
!pip install transformers
import torch,numpy
from transformers import AutoTokenizer,AutoModelForTokenClassification
txt="In Sendai who did you wait for?"
brt="KoichiYasuoka/roberta-base-english-ud-goeswith"
tkz=AutoTokenizer.from_pretrained(brt)
mdl=AutoModelForTokenClassification.from_pretrained(brt)
v,l=tkz(txt,return_offsets_mapping=True),mdl.config.id2label
w,u=v["input_ids"],[txt[s:e] for s,e in v["offset_mapping"] if s<e]
x=[w[:i]+[tkz.mask_token_id]+w[i+1:]+[j] for i,j in enumerate(w[1:-1],1)]
with torch.no_grad():
  m=mdl(input_ids=torch.tensor(x)).logits.numpy()[:,1:-2,:]
r=[1 if i==0 else -1 if l[i].endswith("|root") else 0 for i in range(len(l))]
m+=numpy.where(numpy.add.outer(numpy.identity(m.shape[0]),r)==0,0,numpy.nan)
# best logit and best label for every token pair
d,p=numpy.nanmax(m,axis=2),numpy.nanargmax(m,axis=2)
print(" ".join(x[:9].rjust(9) for x in u))
for i,j in enumerate(u):
  print("\n"+" ".join("{:9.3f}".format(x) for x in d[i])," ",j)
  print(" ".join(l[x].split("|")[-1][:9].rjust(9) for x in p[i]))
The logits (log-odds) of the adjacency matrix came out as follows.
In Send ai who did you wait for ?
1.327 4.288 2.600 2.184 1.790 2.321 4.332 9.080 7.557 In
root nmod goeswith obj punct punct acl:relcl case punct
16.037 5.150 18.021 6.877 2.457 3.484 5.519 10.393 8.274 Send
case root goeswith obj aux punct acl:relcl case punct
3.192 2.593 1.640 2.150 1.640 1.841 1.370 6.699 6.378 ai
case compound root goeswith nmod goeswith parataxis case goeswith
8.258 8.220 2.147 2.297 2.162 3.264 4.715 11.902 8.481 who
case nmod case root aux nsubj acl:relcl case punct
4.999 5.457 2.181 4.621 1.293 2.893 2.874 9.328 7.291 did
case nmod case obj root punct acl:relcl case punct
4.649 7.236 2.163 3.127 1.573 1.225 4.021 9.544 6.739 you
case nmod nmod obj nmod root acl case punct
4.544 13.536 3.379 11.626 12.723 12.845 15.045 8.809 15.645 wait
case obl punct obl aux nsubj root case punct
3.974 4.175 3.073 4.627 1.688 2.486 2.105 1.952 7.491 for
goeswith obl goeswith obj aux punct conj root goeswith
2.489 2.391 2.534 1.580 1.786 1.961 1.388 9.512 2.031 ?
goeswith nmod goeswith goeswith nmod goeswith goeswith case root
The logits in the "for" column are fairly close together, and you can see that "who"⇒"for" wins only by a narrow margin. Still, this model cannot handle contractions well, which is a bit of a headache. Now, what should I do about that?
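To make that comparison concrete, here is the "for" column copied straight from the printed matrix, with an argmax over the candidate heads (the 1.952 on the diagonal is the score for "for" itself being root):

```python
import numpy

# Values copied from the "for" column of the output above; each row of
# the printed matrix is a candidate head for the dependent "for".
rows=["In","Send","ai","who","did","you","wait","for","?"]
col=numpy.array([9.080,10.393,6.699,11.902,9.328,9.544,8.809,1.952,9.512])
print(rows[col.argmax()])  # "who" at 11.902, not far ahead of "Send" at 10.393
```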