Created
October 2, 2016 10:51
-
-
Save saitodev/3fcfc1b3b5ef05ece1ee47c639280687 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// -*- coding: utf-8 -*- | |
#include <iostream> | |
#include <fstream> | |
#include <memory> | |
#include <vector> | |
#include <cassert> | |
#include <cstdint> | |
#include <boost/iostreams/filtering_stream.hpp> | |
#include <boost/iostreams/filter/gzip.hpp> | |
#include "tensorflow/core/platform/init_main.h" | |
#include "tensorflow/core/public/session.h" | |
#include "tensorflow/core/util/command_line_flags.h" | |
// Reads a binary GraphDef from `graph_filename`, creates a new session with
// default SessionOptions in `*session`, and loads the graph into it.
//
// Returns:
//   - NotFound if the graph file cannot be read or parsed; the message keeps
//     the underlying reader error so a missing file and a corrupt protobuf
//     can be told apart;
//   - Internal if no session object could be created;
//   - otherwise the status returned by Session::Create.
// On a Create failure `*session` is left holding the (graph-less) session.
tensorflow::Status
LoadGraph(const tensorflow::string& graph_filename,
          std::unique_ptr<tensorflow::Session>* session)
{
  tensorflow::GraphDef graph_def;
  const tensorflow::Status load_graph_status = tensorflow::ReadBinaryProto(
      tensorflow::Env::Default(), graph_filename, &graph_def);
  if (!load_graph_status.ok()) {
    // Keep the NotFound code callers see, but append the real reason.
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_filename, "': ",
                                        load_graph_status.error_message());
  }

  // The single-argument NewSession() signals failure by returning nullptr,
  // so check before dereferencing it below.
  tensorflow::Session* new_session =
      tensorflow::NewSession(tensorflow::SessionOptions());
  if (new_session == nullptr) {
    return tensorflow::errors::Internal(
        "Failed to create a new TensorFlow session");
  }
  session->reset(new_session);

  return (*session)->Create(graph_def);
}
// Reads an MNIST image file in gzip-compressed IDX3 format. Returns the pixels
// as a float tensor of shape [num_images, rows * cols]: one flattened,
// row-major image per tensor row, with each byte scaled from 0..255 to 0.0..1.0.
//
// The image count and dimensions come from the IDX header rather than being
// hard-coded; the standard t10k file holds 10000 images of 28x28. The magic
// number is checked so that a wrong file is rejected.
//
// Open, read and format failures abort via CHECK. Unlike assert(), these
// checks stay active in NDEBUG builds, so a missing or truncated file can no
// longer silently produce a tensor of garbage pixels.
tensorflow::Tensor
LoadMnistImages(const tensorflow::string& filename)
{
  const uint32_t kImageMagic = 2051;  // IDX magic: unsigned bytes, 3 dimensions

  std::ifstream fin(filename, std::ios_base::in | std::ios_base::binary);
  CHECK(fin.is_open()) << "Failed to open MNIST image file '" << filename << "'";
  boost::iostreams::filtering_istream s;
  s.push(boost::iostreams::gzip_decompressor());
  s.push(fin);

  // IDX header fields are 32-bit unsigned integers stored big-endian.
  auto read_be32 = [&s, &filename]() -> uint32_t {
    unsigned char b[4] = {0, 0, 0, 0};
    s.read(reinterpret_cast<char*>(b), sizeof(b));
    CHECK(s.gcount() == static_cast<std::streamsize>(sizeof(b)))
        << "Truncated IDX header in '" << filename << "'";
    return (static_cast<uint32_t>(b[0]) << 24) |
           (static_cast<uint32_t>(b[1]) << 16) |
           (static_cast<uint32_t>(b[2]) << 8) |
           static_cast<uint32_t>(b[3]);
  };

  // 16-byte header: magic, image count, rows, cols.
  const uint32_t magic = read_be32();
  CHECK_EQ(magic, kImageMagic)
      << "'" << filename << "' is not an IDX3 image file";
  const tensorflow::int64 N_data = read_be32();
  const tensorflow::int64 N_height = read_be32();
  const tensorflow::int64 N_width = read_be32();
  const tensorflow::int64 N_vec = N_height * N_width;

  tensorflow::Tensor tensor(tensorflow::DT_FLOAT,
                            tensorflow::TensorShape({N_data, N_vec}));
  auto mat = tensor.tensor<float, 2>();

  // Decompress one whole image per read() instead of one byte per get().
  std::vector<char> pixels(static_cast<size_t>(N_vec));
  for (tensorflow::int64 n = 0; n < N_data; ++n) {
    s.read(pixels.data(), N_vec);
    CHECK(s.gcount() == N_vec)
        << "Truncated pixel data in '" << filename << "' at image " << n;
    for (tensorflow::int64 i = 0; i < N_vec; ++i) {
      // Go through uint8_t so bytes >= 128 are not sign-extended.
      mat(n, i) = static_cast<float>(static_cast<uint8_t>(pixels[i])) / 255.0f;
    }
  }
  return tensor;
}
// Reads an MNIST label file in gzip-compressed IDX1 format. Returns the labels
// one-hot encoded as a float tensor of shape [num_labels, 10]: row n holds 1.0
// in the column of label n and 0.0 everywhere else.
//
// The label count comes from the IDX header (10000 for the standard t10k
// file), and the magic number is checked.
//
// Open and read failures and labels outside 0..9 abort via CHECK. These checks
// stay active in NDEBUG builds. That matters because each label is used as a
// column index, so an unchecked bad byte would write outside the tensor.
tensorflow::Tensor
LoadMnistLabels(const tensorflow::string& filename)
{
  const uint32_t kLabelMagic = 2049;      // IDX magic: unsigned bytes, 1 dimension
  const tensorflow::int64 N_vec = 10;     // number of digit classes

  std::ifstream fin(filename, std::ios_base::in | std::ios_base::binary);
  CHECK(fin.is_open()) << "Failed to open MNIST label file '" << filename << "'";
  boost::iostreams::filtering_istream s;
  s.push(boost::iostreams::gzip_decompressor());
  s.push(fin);

  // IDX header fields are 32-bit unsigned integers stored big-endian.
  auto read_be32 = [&s, &filename]() -> uint32_t {
    unsigned char b[4] = {0, 0, 0, 0};
    s.read(reinterpret_cast<char*>(b), sizeof(b));
    CHECK(s.gcount() == static_cast<std::streamsize>(sizeof(b)))
        << "Truncated IDX header in '" << filename << "'";
    return (static_cast<uint32_t>(b[0]) << 24) |
           (static_cast<uint32_t>(b[1]) << 16) |
           (static_cast<uint32_t>(b[2]) << 8) |
           static_cast<uint32_t>(b[3]);
  };

  // 8-byte header: magic, label count.
  const uint32_t magic = read_be32();
  CHECK_EQ(magic, kLabelMagic)
      << "'" << filename << "' is not an IDX1 label file";
  const tensorflow::int64 N_data = read_be32();

  tensorflow::Tensor tensor(tensorflow::DT_FLOAT,
                            tensorflow::TensorShape({N_data, N_vec}));
  auto mat = tensor.tensor<float, 2>();
  mat.setZero();

  std::vector<char> labels(static_cast<size_t>(N_data));
  s.read(labels.data(), N_data);
  CHECK(s.gcount() == N_data)
      << "Truncated label data in '" << filename << "'";

  for (tensorflow::int64 n = 0; n < N_data; ++n) {
    // Labels are unsigned bytes; going through uint8_t keeps them non-negative.
    const tensorflow::int64 label = static_cast<uint8_t>(labels[n]);
    CHECK_LT(label, N_vec)
        << "Invalid label " << label << " at index " << n << " in '"
        << filename << "'";
    mat(n, label) = 1.0f;
  }
  return tensor;
}
int main(int argc, char* argv[]) | |
{ | |
tensorflow::string graph_filename = "trained_graph.pb"; | |
tensorflow::string image_filename = "MNIST_data/t10k-images-idx3-ubyte.gz"; | |
tensorflow::string label_filename = "MNIST_data/t10k-labels-idx1-ubyte.gz"; | |
const bool parse_result = tensorflow::ParseFlags( | |
&argc, argv, | |
{tensorflow::Flag("graph", &graph_filename), | |
tensorflow::Flag("image", &image_filename), | |
tensorflow::Flag("label", &label_filename)}); | |
if (!parse_result) { | |
LOG(ERROR) << "Error parsing command-line flags."; | |
return -1; | |
} | |
tensorflow::port::InitMain(argv[0], &argc, &argv); | |
if (argc > 1) { | |
LOG(ERROR) << "Unknown argument " << argv[1]; | |
return -1; | |
} | |
std::unique_ptr<tensorflow::Session> session; | |
auto load_graph_status = LoadGraph(graph_filename, &session); | |
if (!load_graph_status.ok()) { | |
LOG(ERROR) << load_graph_status.error_message(); | |
return -1; | |
} | |
auto x = LoadMnistImages(image_filename); | |
auto y_ = LoadMnistLabels(label_filename); | |
std::vector<tensorflow::Tensor> outputs; | |
auto session_run_status = session->Run({{"x:0", x}, {"y_:0", y_}}, | |
{"accuracy:0"}, | |
{}, | |
&outputs); | |
if (!session_run_status.ok()) { | |
LOG(ERROR) << session_run_status.error_message(); | |
return -1; | |
} | |
float accuracy = outputs[0].scalar<float>()(0); | |
std::cout << "accuracy = " << accuracy << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment