パスワードを忘れた? アカウント作成
15800315 journal
人工知能

yasuokaの日記: Swinv2ForImageClassificationで漢字の総画数を求める画像分類モデル

日記 by yasuoka

Transformers 4.22がSwin Transformer V2をサポートしたので、戸籍統一文字の総画数を元に、画像中の漢字に対して画数を求める画像分類モデルを、試作してみることにした。Google Colaboratory (GPU)だと、以下のような感じ。

!pip install 'transformers>=4.22'
import os,glob
url="https://github.com/KoichiYasuoka/kosekimoji-strokes"
d=os.path.basename(url)
!test -d $d || git clone --depth=1 $url
img=glob.glob(d+"/*/*.png")
lid={str(i):i for i in range(65)}
class ImageDataset(object):
  def __init__(self,files,label2id):
    self.files=files
    self.label2id=label2id
  __len__=lambda self:len(self.files)
  def __getitem__(self,i):
    from PIL import Image
    from torchvision.transforms.functional import to_tensor
    f=self.files[i]
    return {"pixel_values":to_tensor(Image.open(f).convert("L")),"labels":[self.label2id[f.split("/")[-2]]]}
from transformers import Swinv2ForImageClassification,Swinv2Config,DefaultDataCollator,TrainingArguments,Trainer
trainDS=ImageDataset(img,lid)
mdl=Swinv2ForImageClassification(Swinv2Config(image_size=200,num_channels=1,num_labels=len(lid),label2id=lid,id2label={i:l for l,i in lid.items()}))
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=32,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
trn=Trainer(args=arg,data_collator=DefaultDataCollator(),model=mdl,train_dataset=trainDS)
trn.train()
trn.save_model("my-kosekimoji")

GPUだと45分程度で、my-kosekimojiに画像分類モデルが出来上がる。うまく出来たら、法人番号5470001008156(金「⿱刀比」羅醬油株式会社)の商号画像(9文字)に対し、各漢字の画数を求めてみよう。

!curl -A Mozilla -L 'https://www.houjin-bangou.nta.go.jp/image?imageid=00005096' -o test.png
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor
from transformers import AutoModelForImageClassification
mdl=AutoModelForImageClassification.from_pretrained("my-kosekimoji")
e=Image.open("test.png")
w,h=e.size
with torch.no_grad():
  x=mdl(torch.stack([to_tensor(e.crop((x,0,x+h,h)).resize((200,200)).convert("L")) for x in range(0,w,h)],0)).logits
print(torch.argmax(x,axis=1).tolist())

私(安岡孝一)の手元では、以下の結果になった。

[8, 8, 19, 19, 8, 11, 6, 6, 7]

うーん、「⿱刀比」が8画、「株」が11画になってしまっていて、もう一息だ。さて、これ、教師画像を増やせば、もうちょっと精度あがるのかな。

この議論は、yasuoka (21275)によって ログインユーザだけとして作成されたが、今となっては 新たにコメントを付けることはできません。
typodupeerror

「毎々お世話になっております。仕様書を頂きたく。」「拝承」 -- ある会社の日常

読み込み中...