Last active
March 2, 2022 05:25
-
-
Save Unbinilium/d9de5b2d544ad5fe222b0133981687ae to your computer and use it in GitHub Desktop.
Inferring an MNIST TorchScript model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
namespace dnn { | |
template <typename T> | |
class ts_mnist { | |
public: | |
ts_mnist( | |
const std::string& path, | |
const std::vector<T>& labels | |
) : _path(path), _labels(labels) { | |
_module = torch::jit::load(_path); | |
_inputs.resize(1); | |
} | |
auto inferring(const cv::Mat& image) noexcept { | |
cv::cvtColor(image, _gray, cv::COLOR_BGR2GRAY); | |
cv::resize(_gray, _gray, cv::Size(28, 28)); | |
_tensor_image = torch::from_blob(_gray.data, { _gray.rows, _gray.cols }, torch::kUInt8); | |
_tensor_image_normed = (_tensor_image / 255.f).sub_(0.5f).div_(0.5f); | |
_inputs[0] = _tensor_image_normed.unsqueeze_(0).unsqueeze_(0); | |
_output = _module.forward(_inputs).toTensor(); | |
_index = _output.argmax().item<int32_t>(); | |
return std::pair<T, torch::Tensor>(_labels.at(_index), _output.index({0, _index})); | |
} | |
private: | |
const std::string _path; | |
const std::vector<T> _labels; | |
torch::jit::script::Module _module; | |
torch::Tensor _tensor_image; | |
torch::Tensor _tensor_image_normed; | |
torch::Tensor _output; | |
int32_t _index; | |
std::vector<torch::jit::IValue> _inputs; | |
cv::Mat _gray; | |
}; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.