Created
February 11, 2021 00:26
Image Recognition with VGG-16 Network in Ruby (ja)
require 'magro'
require 'json'
require 'torch'
require 'torchvision'
# Load the pretrained VGG-16 network.
vgg = TorchVision::Models::VGG16.new
vgg.load_state_dict(Torch.load('vgg16_.pth'))
# Read the image.
img = Magro::IO.imread('A_Golden_Retriever-9_(Barras).JPG')
# Crop a square from the center of the image.
height, width, = img.shape
img_size = [height, width].min
y_offset = (height - img_size) / 2
x_offset = (width - img_size) / 2
img = img[y_offset...(y_offset + img_size), x_offset...(x_offset + img_size), true]
# Resize the image to 224x224.
img = Magro::Transform.resize(img, height: 224, width: 224)
# Scale pixel values to the range [0, 1].
img = Numo::SFloat.cast(img) / 255.0
# Convert the image to a torch.rb tensor and reorder the axes to [channel, height, width].
img_torch = Torch.from_numo(img).permute(2, 0, 1)
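# Normalize each channel with the ImageNet mean and standard deviation.
mean = Torch.tensor([0.485, 0.456, 0.406])
std = Torch.tensor([0.229, 0.224, 0.225])
normalize = TorchVision::Transforms::Normalize.new(mean, std)
# Keep the returned tensor: unless the transform runs in place, the result would otherwise be discarded.
img_torch = normalize.call(img_torch)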
# Add a batch dimension so the tensor has shape [1, 3, 224, 224].
img_torch = img_torch.expand(1, -1, -1, -1)
# Feed the preprocessed image to the pretrained model (eval switches it to inference mode).
vgg.eval
out = vgg.forward(img_torch)
# Get the index of the largest element in the final layer's output.
class_idx = out.numo[0, true].max_index
# Print the ImageNet class corresponding to that index.
imagenet_classes = JSON.parse(File.read('imagenet_class_index.json'))
puts "class: #{imagenet_classes[class_idx.to_s].last}"
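
The script above reports only the single highest-scoring class. As a rough sketch of how one might list the top few candidates with probabilities instead, the plain-Ruby example below applies a softmax to the raw scores and picks the k largest. The helper names `softmax` and `top_k` and the toy data are hypothetical, introduced here for illustration; the only assumption about the label file is that it maps an index string to an `[id, label]` pair, which is how the script above reads it.

```ruby
require 'json'

# Convert raw scores (logits) to probabilities with a numerically stable softmax.
def softmax(logits)
  max = logits.max
  exps = logits.map { |v| Math.exp(v - max) }
  sum = exps.sum
  exps.map { |e| e / sum }
end

# Return the k most probable classes as [label, probability] pairs.
# class_index is assumed to look like { "0" => ["some_id", "some_label"], ... }.
def top_k(logits, class_index, k = 5)
  softmax(logits)
    .each_with_index
    .max_by(k) { |prob, _idx| prob }
    .map { |prob, idx| [class_index[idx.to_s].last, prob] }
end

# Toy stand-ins for the real 1000-class output and label file (hypothetical values).
toy_index = JSON.parse('{"0": ["n0", "cat"], "1": ["n1", "dog"], "2": ["n2", "bird"]}')
toy_logits = [1.0, 3.0, 0.5]

top_k(toy_logits, toy_index, 2).each do |label, prob|
  puts format('%-6s %.3f', label, prob)
end
```

With the script above, the logits could presumably be obtained as a plain array via `out.numo[0, true].to_a` and passed to `top_k` together with `imagenet_classes`.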