Created April 25, 2022 15:01
Save hadifar/7b89bc435279829fb7923066a8a63869 to your computer and use it in GitHub Desktop.
A simple example of SVM text classification.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import numpy as np | |
from joblib import dump, load | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics import accuracy_score | |
from sklearn.pipeline import Pipeline | |
from sklearn.svm import SVC | |
def load_data(args):
    """Return a tiny hard-coded train/test split for the SVM demo.

    Args:
        args: parsed command-line namespace; accepted for interface
            compatibility but not used.

    Returns:
        A 4-tuple ``(x_train, x_test, y_train, y_test)`` of lists: raw text
        strings and their integer labels.
    """
    # Keep each sentence next to its label so the two cannot drift apart.
    train_pairs = [
        ('svm is all you need !!!', 0),
        ('test is test', 1),
        ('this is another test', 1),
        ('this is test 3', 1),
        ('svm is more than you need', 0),
    ]
    test_pairs = [('svm is good ?', 0)]

    train_texts = [text for text, _ in train_pairs]
    train_labels = [label for _, label in train_pairs]
    test_texts = [text for text, _ in test_pairs]
    test_labels = [label for _, label in test_pairs]
    return train_texts, test_texts, train_labels, test_labels
def do_inference(args):
    """Load the saved pipeline and print its prediction for one example.

    Args:
        args: parsed command-line namespace; only ``save_path`` is read.

    Prints:
        A list of predicted class labels, one per input sentence.
    """
    # NOTE(review): joblib.load unpickles the file, so only point save_path
    # at files you trust.
    model = load(args.save_path)
    # Input looks pre-stemmed (e.g. "appli", "onlin") -- presumably matching
    # the original training preprocessing; confirm.
    inps = ['appli onlin medicar access comput right']  # gold class 1
    probabilities = model.predict_proba(inps)
    # argmax returns a column index into ``classes_``, not a label; map it
    # back so the output is correct even when labels are not 0..k-1.
    predictions = model.classes_[np.argmax(probabilities, axis=1)].tolist()
    print(predictions)
def train_classifier(args):
    """Train (or reuse) a TF-IDF + linear SVM pipeline and print its accuracy.

    If a file already exists at ``args.save_path``, it is loaded and reused
    without retraining. Otherwise a new pipeline is fit on the training split
    and dumped to that path. In both cases, accuracy on the held-out split is
    printed.

    Args:
        args: parsed command-line namespace; ``save_path`` is read here and
            ``args`` is forwarded to ``load_data``.
    """
    x_train, x_valid, y_train, y_valid = load_data(args)
    print('loading data finished...')

    # Reuse a previously trained model if one is cached on disk.
    # NOTE: the cache is never invalidated -- delete the file to retrain
    # after changing the data or hyperparameters.
    if os.path.isfile(args.save_path):
        print('load existing stat classifier')
        pipeline = load(args.save_path)
    else:
        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer(ngram_range=(1, 2))),
            # probability=True enables predict_proba. Per the sklearn docs
            # these estimates come from internal cross-validated Platt
            # scaling and can disagree with predict().
            ('svc', SVC(kernel='linear', C=1, probability=True)),
        ], verbose=True)
        pipeline.fit(x_train, y_train)
        dump(pipeline, args.save_path)

    probabilities = pipeline.predict_proba(x_valid)
    # argmax gives a column index into ``classes_``; translate it to the
    # actual label before comparing against y_valid.
    predictions = pipeline.classes_[np.argmax(probabilities, axis=1)]
    print('ACC ' + str(accuracy_score(y_valid, predictions)))
def main(args):
    """Run the statistical (TF-IDF + SVM) workflow end to end.

    First trains or reloads and evaluates the classifier, then runs inference
    with the saved model. Only this statistical path is implemented here; a
    neural (BERT) path was sketched in earlier revisions but is not wired in.

    Args:
        args: parsed command-line namespace forwarded unchanged to each step.
    """
    for step in (train_classifier, do_inference):
        step(args)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train and run a TF-IDF + linear SVM text classifier.')
    # NOTE(review): --debug is parsed but not read by any code visible here.
    # The default is a string to match type=str (argparse only runs string
    # defaults through `type`, so an int default would bypass it).
    parser.add_argument('--debug', type=str,
                        default='1')
    parser.add_argument('--save_path', type=str, default='voting_domain_classifier.joblib',
                        help='path of the joblib file the trained pipeline is '
                             'saved to and loaded from')
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.