import shap
def _resolve_split(name, value):
    """Return ``value``, or the module-level global called ``name`` when ``value`` is None.

    Keeps the original calling convention working: callers that never pass
    the data splits still get the globals ``X_train``, ``y_train``,
    ``X_test`` and ``y_test`` that earlier versions read implicitly.

    Raises:
        NameError: if ``value`` is None and no global called ``name`` exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name!r} was not passed to shap_plot() and is not defined at module level"
        ) from None


def shap_plot(base_model, instance, *, X_train=None, y_train=None,
              X_test=None, y_test=None, max_display=13):
    """Fit a model and draw SHAP plots for global and local explanation.

    Created by: Fahim Muntasir
    Date: 7/12/23

    The model is fitted on the training split. SHAP values are then computed
    for the whole test split. A global bar plot (all test samples), a local
    bar plot and a force plot for the chosen sample are drawn.

    Args:
        base_model: Estimator used for prediction. It must provide ``fit`` and
            ``predict``. Tree-based algorithms are preferred.
        instance: Positional index of the test-set sample to explain.
        X_train, y_train, X_test, y_test: Data splits. Each one that is
            omitted (or None) falls back to the module-level global of the
            same name.
        max_display: Maximum number of features shown in the local bar plot.

    Returns:
        The object returned by ``shap.plots.force`` for the chosen sample.

    Raises:
        NameError: if a data split is neither passed nor defined globally.
    """
    # Resolve every split up front so a missing one fails before the
    # (potentially slow) model fit rather than halfway through plotting.
    X_train = _resolve_split("X_train", X_train)
    y_train = _resolve_split("y_train", y_train)
    X_test = _resolve_split("X_test", X_test)
    y_test = _resolve_split("y_test", y_test)

    # Assumes fit() returns the fitted estimator (sklearn/xgboost convention).
    model = base_model.fit(X_train, y_train)

    # Build the explainer with the training data as background, then score
    # every test sample.
    explainer = shap.Explainer(model, X_train)
    shap_values = explainer(X_test)

    # Load SHAP's JS visualisation code so the force plot can render.
    shap.initjs()

    print(f"Sample number: {instance}")
    preds = model.predict(X_test)
    # .iloc => y_test is expected to be a pandas Series/DataFrame, indexed
    # positionally like preds and shap_values.
    print(f"Actual class: {y_test.iloc[instance]}")
    print(f"Predicted class: {preds[instance]}")

    # Global explanation: mean |SHAP| per feature across the test set.
    shap.plots.bar(shap_values)
    # Local explanation: feature contributions for the chosen sample only.
    shap.plots.bar(shap_values[instance], max_display=max_display)
    # Force plot for the same sample; returned so a notebook displays it.
    return shap.plots.force(shap_values[instance])
# Using this code is simple: just call the function as shown below.
# Example call. Assumes XGBClassifier (from xgboost) has been imported and that
# X_train, y_train, X_test and y_test are defined at module level.
shap_plot(base_model=XGBClassifier(), instance=7)  # pick any test-set row index