Created
October 17, 2014 14:00
-
-
Save jclevesque/e5a018418b22a4c69749 to your computer and use it in GitHub Desktop.
Python wrapper for budgeted svm toolbox
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import os | |
import re | |
import subprocess | |
class BudgetSVMToolbox:
    """
    Wrapper around the BudgetedSVM toolbox as provided by its authors.

    Requires two executables, budgetedsvm-train and budgetedsvm-predict,
    which are launched as subprocesses. Data is exchanged with them through
    files (svmlight-format inputs, a model file and a prediction file).
    """

    def __init__(self, budgeted_svm_bin_path, epochs=1, algorithm=4,
                 kernel_type=0, L=0.0001, budget=100, budget_strategy=1,
                 gamma=-1, degree=-1, slope=2., intercept=1., bias=False,
                 z=50000, verbose=True, output_folder='.'):
        '''
        Store the toolbox options and make sure the output folder exists.

        Parameters:
        -----------
        budgeted_svm_bin_path: folder containing the executables of the
            budgeted svm toolbox (budgetedsvm-train, budgetedsvm-predict).
            Must be compiled manually. A trailing path separator is
            optional; an empty string leaves the lookup to the system PATH.
        epochs: maximum number of epochs
        algorithm: 0 - pegasos
                   1 - AMM batch
                   2 - AMM online
                   3 - LLSVM
                   4 - BSGD
        kernel_type: forwarded as -K. Types 0 and 1 use `gamma`, type 2
            uses `degree` and `intercept`, type 4 uses `slope` and
            `intercept` (see the toolbox documentation for the kernels).
        L: regularization parameter for SVM updates (lambda)
        gamma: kernel width for RBF (used for BSGD and LLSVM only).
            The default -1 means "use 1 / d", where d is the data
            dimensionality given to train().
        budget: budget, maximum number of support vectors (BSGD) or
            landmark points (LLSVM)
        budget_strategy: BSGD: 0-removal, 1-merging, LLSVM: 0-random,
            1-kmeans, 2-kmedoids
        bias: stored on the instance but currently not forwarded to the
            executables.
        z: forwarded as -z to both executables.
        verbose: whether budgetedsvm-train runs in verbose mode.
        output_folder: intermediate folder in which to save the model and
            prediction files. Created if it does not exist.
        other/missing parameters: see budgeted svm toolbox.
        '''
        self.epochs = epochs
        self.algorithm = algorithm
        self.L = L
        self.budget = budget
        self.budget_strategy = budget_strategy
        self.kernel_type = kernel_type
        self.gamma = gamma
        self.degree = degree
        self.slope = slope
        self.intercept = intercept
        self.bias = bias
        self.z = z
        # Set budgeted-svm to verbose if we are in verbose mode.
        self.v = verbose
        self.output_folder = output_folder
        os.makedirs(output_folder, exist_ok=True)
        self.exec_path = budgeted_svm_bin_path
        # Defined up front so test() can be called on a model that a
        # previous run left in output_folder, instead of failing with an
        # AttributeError when train() has not been called yet.
        self.model_file = os.path.join(self.output_folder, 'bsgd_model')
        self.prediction_file = os.path.join(self.output_folder, 'predictions')

    def train(self, datafile, d):
        '''
        Train a model by running budgetedsvm-train on a data file.

        Requires filenames instead of pre-loaded datasets. The model is
        written to `self.model_file` inside the output folder.

        Parameters:
        -----------
        datafile: name of datafile, svmlight format.
        d: dimensionality of data in datafile

        Raises:
        -------
        Exception: if the executable writes anything to stderr.
        '''
        self.model_file = os.path.join(self.output_folder, 'bsgd_model')
        self.prediction_file = os.path.join(self.output_folder, 'predictions')

        cmd = self._train_command(datafile, d)
        print(' '.join(cmd))
        self._run(cmd)

    def test(self, data, prediction_file=None):
        '''
        Run budgetedsvm-predict on a data file and return the accuracy.

        Parameters:
        -----------
        data: name of the data file to predict on.
        prediction_file: where the executable writes its predictions;
            defaults to `self.prediction_file`.

        Returns:
        --------
        Accuracy as a fraction, computed as 1 - error_rate / 100 where
        error_rate is the number parsed from the last line of the
        executable's output.

        Raises:
        -------
        Exception: if the executable writes anything to stderr.
        ValueError: if no number is found in the last output line.
        '''
        p_file = self.prediction_file if prediction_file is None \
            else prediction_file

        cmd = [
            os.path.join(self.exec_path, 'budgetedsvm-predict'),
            '-z', str(self.z),
            # Always verbose otherwise we don't get testing accuracy.
            '-v', '1',
            data, self.model_file, p_file,
        ]
        print(' '.join(cmd))
        output = self._run(cmd)

        accuracy = 1 - self._parse_error_rate(output) / 100
        print('Tested budgeted SVM on given dataset, accuracy : {}'.format(accuracy))
        return accuracy

    def _train_command(self, datafile, d):
        '''Build the budgetedsvm-train argument list for `datafile`.'''
        cmd = [
            os.path.join(self.exec_path, 'budgetedsvm-train'),
            '-e', str(self.epochs),
            '-A', str(self.algorithm),
            '-L', str(self.L),
            '-B', str(self.budget),
            '-m', str(self.budget_strategy),
            '-K', str(self.kernel_type),
            '-z', str(self.z),
            # Pass a number: formatting the bool directly would send the
            # literal text "True"/"False" to the executable.
            '-v', str(int(self.v)),
        ]

        if self.kernel_type in (0, 1):
            # Resolve the -1 sentinel per call without overwriting
            # self.gamma, so retraining on data of another dimensionality
            # recomputes the default instead of reusing a stale value.
            gamma = 1.0 / d if self.gamma == -1 else self.gamma
            cmd += ['-g', str(gamma)]
        if self.kernel_type == 2:
            cmd += ['-d', str(self.degree)]
        if self.kernel_type == 4:
            # The slope is passed through the same -d flag as the degree.
            cmd += ['-d', str(self.slope)]
        if self.kernel_type in (2, 4):
            cmd += ['-i', str(self.intercept)]

        # Arguments stay separate list items so that paths containing
        # spaces reach the executable intact.
        cmd += [datafile, self.model_file]
        return cmd

    @staticmethod
    def _run(cmd):
        '''Run `cmd`, print and return its stdout; raise on stderr output.'''
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        # Waits for program to end, receives output/error.
        output, error = process.communicate()
        output = output.decode(errors='replace')
        error = error.decode(errors='replace')
        print(output)
        if error:
            raise Exception(error)
        return output

    @staticmethod
    def _parse_error_rate(output):
        '''Return the number found on the last line of `output`.'''
        lines = output.strip().splitlines()
        last_line = lines[-1] if lines else ''
        # Prefer a decimal number; fall back to a plain integer so that a
        # rate printed without a fractional part is still understood.
        match = (re.search(r'\d+\.\d+', last_line)
                 or re.search(r'\d+', last_line))
        if match is None:
            raise ValueError(
                'Could not find an error rate in the last output line: '
                '{!r}'.format(last_line))
        return float(match.group())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment