Last active
June 25, 2021 16:50
-
-
Save kyoto-cheng/21ff23cc0019a5b9919845d64484d3b5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the bag-of-words training matrix, with the vocabulary restricted to
# the terms kept by feature selection (X_names).
vectorizer = feature_extraction.text.CountVectorizer(vocabulary=X_names)
X_train = vectorizer.fit_transform(corpus)

# Candidate models under test: Naive Bayes, Random Forest and Decision Tree.
NB_Classifier = naive_bayes.MultinomialNB()
RForest_Classifier = RandomForestClassifier()
DTree_Classifier = DecisionTreeClassifier()

# One pipeline per model; each chains the shared CountVectorizer with its
# classifier.
NB_Pipeline = pipeline.Pipeline([
    ("vectorizer", vectorizer),
    ("classifier", NB_Classifier),
])
RForest_Pipeline = pipeline.Pipeline([
    ("vectorizer", vectorizer),
    ("classifier", RForest_Classifier),
])
DTree_Pipeline = pipeline.Pipeline([
    ("vectorizer", vectorizer),
    ("classifier", DTree_Classifier),
])
# Plot the confusion-matrix heatmap of the selected model pipeline.
def model_plot(pipeline, class_names=("EDA", "Method", "Model", "Statistics")):
    """Train the pipeline's classifier and draw its test confusion matrix.

    Relies on the module-level ``X_train``, ``y_train``, ``df_test`` and
    ``y_test``. The heatmap is drawn on the axes returned by
    ``plt.subplot()``; nothing is returned.

    Args:
        pipeline: A pipeline with a "classifier" step. Note that this
            parameter name shadows any module named ``pipeline`` inside the
            function body.
        class_names: Tick labels for both axes of the matrix, in the row and
            column order produced by ``confusion_matrix``. Defaults to the
            four labels that were previously hard-coded.
    """
    # Only the classifier step is fitted: X_train is already the vectorized
    # corpus, so it must not pass through the vectorizer step again.
    pipeline["classifier"].fit(X_train, y_train)

    # The test questions go through the whole pipeline (vectorizer, then
    # classifier).
    X_test = df_test.Questions.values
    predicted = pipeline.predict(X_test)

    cm = confusion_matrix(y_test, predicted)
    ax = plt.subplot()
    # annot=True annotates each cell; fmt='g' disables scientific notation.
    sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap="BuPu")

    # Labels, title and ticks.
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    tick_labels = list(class_names)
    ax.xaxis.set_ticklabels(tick_labels)
    ax.yaxis.set_ticklabels(tick_labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment