Created
April 21, 2013 14:42
-
-
Save y-tag/5429834 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include <jubatus/client.hpp>
// Build: g++ -O2 -o eval_ranking eval_ranking.cpp `pkg-config pficommon --libs --cflags` -lmsgpack -ljubatus_mpio -ljubatus_msgpack-rpc -std=c++0x
using jubatus::ranking::datum; | |
// Parses one text line of the form
//
//     <relevance> <key>:<value> <key>:<value> ...
//
// where fields are separated by spaces or tabs.
//
// @param line       the raw input line (not modified).
// @param relevance  receives the first field, converted with strtod().
// @param qid        receives the integer value of the field whose key starts
//                   with "qid"; left untouched when the line has no such field.
// @param d          its string_values and num_values are cleared, then every
//                   non-qid field is appended to num_values as (key, value).
// @return 1 when a relevance field was found, 0 when the line is empty or
//         contains only separators (outputs other than the cleared `d` are
//         then left untouched).
//
// Note: relies on strtok(), which keeps hidden global state, so this function
// is neither reentrant nor thread-safe.
int parse_line(const std::string &line, float *relevance, int *qid, datum *d) {
  d->string_values.clear();
  d->num_values.clear();

  // strtok() writes NUL bytes into the buffer it scans, so work on a private,
  // NUL-terminated copy. A std::vector is used because variable-length arrays
  // are not part of standard C++.
  std::vector<char> cbuff(line.begin(), line.end());
  cbuff.push_back('\0');

  char *p = strtok(&cbuff[0], " \t");
  if (p == NULL) {
    // Empty or whitespace-only line: there is no relevance field to convert.
    return 0;
  }
  *relevance = static_cast<float>(strtod(p, NULL));

  while (1) {
    char *f = strtok(NULL, ":");
    char *v = strtok(NULL, " \t");
    if (v == NULL) {
      // No complete key:value pair is left (f may be NULL or a trailing
      // fragment without a ':').
      break;
    }
    // The key token is delimited by ':' only, so when fields are separated by
    // more than one blank the extra blanks end up at the front of the key.
    f += strspn(f, " \t");
    if (strncmp(f, "qid", 3) == 0) {
      *qid = static_cast<int>(strtol(v, NULL, 10));
    } else {
      d->num_values.push_back(std::make_pair(std::string(f), strtod(v, NULL)));
    }
  }
  return 1;
}
int main(int argc, char **argv) { | |
std::string host = "127.0.0.1"; | |
int port = 9199; | |
std::string name = "test"; | |
jubatus::ranking::client::ranking client(host, port, 10.0); | |
client.clear(name); | |
if (argc < 5) { | |
fprintf(stderr, "%s train_in valid_in test_in valid_out test_out\n", argv[0]); | |
exit(1); | |
} | |
const char *train_in = argv[1]; | |
const char *valid_in = argv[2]; | |
const char *test_in = argv[3]; | |
const char *valid_out = argv[4]; | |
const char *test_out = argv[5]; | |
std::string buff; | |
std::map<int, std::vector<std::pair<float, datum> > > train_data; | |
std::ifstream trifs; | |
fprintf(stderr, "read train data...\n"); | |
trifs.open(train_in); | |
while (getline(trifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
train_data[qid].push_back(std::make_pair(rel, d)); | |
} | |
fprintf(stderr, "done\n"); | |
fprintf(stderr, "train start...\n"); | |
std::vector<std::vector<std::pair<float, datum> > > train_data_vec; | |
for (auto train_itr = train_data.begin(); train_itr != train_data.end(); ++train_itr) { | |
train_data_vec.push_back(train_itr->second); | |
} | |
srand(1000); | |
/* | |
for (size_t l = 0; l < loop; ++l) { | |
random_shuffle(train_data_vec.begin(), train_data_vec.end()); | |
for (size_t i = 0; i < train_data_vec.size(); ++i) { | |
client.train(name, train_data_vec[i]); | |
} | |
} | |
*/ | |
size_t n = 0; | |
while (n < 10 * train_data_vec.size()) { | |
int m = rand() % train_data_vec.size(); | |
client.train(name, train_data_vec[m]); | |
n++; | |
} | |
fprintf(stderr, "done\n"); | |
std::vector<datum> valid_data; | |
std::ifstream vaifs; | |
std::ofstream vaofs; | |
fprintf(stderr, "validation start...\n"); | |
vaifs.open(valid_in); | |
vaofs.open(valid_out); | |
while (getline(vaifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
valid_data.push_back(d); | |
auto results = client.estimate(name, valid_data); | |
vaofs << results[0] << std::endl; | |
valid_data.clear(); | |
} | |
fprintf(stderr, "done\n"); | |
std::vector<datum> test_data; | |
std::ifstream teifs; | |
std::ofstream teofs; | |
fprintf(stderr, "test start...\n"); | |
teifs.open(test_in); | |
teofs.open(test_out); | |
while (getline(teifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
test_data.push_back(d); | |
auto results = client.estimate(name, test_data); | |
teofs << results[0] << std::endl; | |
test_data.clear(); | |
} | |
fprintf(stderr, "done\n"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment