SHAP Explainable AI - Easy Modular Code
import shap

def shap_plot(base_model, instance):
    """
    Description: SHAP plot for feature importance, with local and global explanations.
    Created By: Fahim Muntasir
    Date: 7/12/23

    base_model: Algorithm used for prediction. Tree-based algorithms are preferred.
    instance: Index of any instance from the test set.
    Return: Bar plots for global and local explanation, plus a force plot.

    Note: X_train, y_train, X_test and y_test must already be defined in the
    enclosing scope; the function reads them as globals.
    """

    model = base_model.fit(X_train, y_train)      # fit the model as usual

    explainer = shap.Explainer(model, X_train)    # build the SHAP explainer
    shap_values = explainer(X_test)               # compute SHAP values for the test set

    shap.initjs()           # load the JavaScript needed for the interactive plots
    print(f"Sample number: {instance}")

    preds = model.predict(X_test)
    probabilities = model.predict_proba(X_test)

    print(f"Actual class: {y_test.iloc[instance]}")
    print(f"Predicted class: {preds[instance]}")
    print(f"Predicted probabilities: {probabilities[instance]}")

    # SHAP plots
    shap.plots.bar(shap_values)                             # global explanation
    shap.plots.bar(shap_values[instance], max_display=13)   # local explanation

    # force plot for the chosen instance
    return shap.plots.force(shap_values[instance])

Using this code is dead simple. Just call the function like below.

shap_plot(XGBClassifier(), 7) # change the instance number according to your liking
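
Since shap_plot reads X_train, X_test, y_train and y_test from the enclosing scope, they need to exist before you call it. Here is a minimal end-to-end sketch; the breast cancer dataset and the xgboost dependency are assumptions for illustration, not part of the original gist:

import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

# Load a toy binary-classification dataset as pandas objects
# (any tabular DataFrame/Series split works the same way).
data = load_breast_cancer(as_frame=True)
X, y = data.data, data.target

# The names X_train, X_test, y_train, y_test must match exactly,
# because shap_plot references them as globals.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

shap_plot(XGBClassifier(), 7)  # explain test instance number 7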