Created
October 20, 2023 15:32
-
-
Save iolloyd/50f9fae4c483824d7b753d49db658dfc to your computer and use it in GitHub Desktop.
Simple binary Shapley values for data records against query output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import math | |
def query_age(dataset):
    """
    Select the records whose 'age' field is below 26.

    Records are returned in their original order as a new list; the
    input sequence itself is left unmodified.
    """
    young = []
    for entry in dataset:
        if entry['age'] < 26:
            young.append(entry)
    return young
def query_trust(dataset):
    """
    Select the records whose 'trust' field is exactly 100.

    Records are returned in their original order as a new list.
    """
    def is_fully_trusted(entry):
        return entry['trust'] == 100

    return list(filter(is_fully_trusted, dataset))
def shapley_value(dataset, query_func):
    """
    Compute a binary Shapley value for every record of ``dataset``.

    ``query_func`` is called with a list of records and must return a
    collection supporting ``in``. For a coalition containing record j,
    the marginal contribution of j is 1 when j is in the query output
    for the coalition but not in the output for the coalition with j
    removed, and 0 otherwise. The output of the empty coalition is
    taken to be empty without calling ``query_func``.

    Returns a list of values, one per record, in dataset order.
    Every non-empty coalition is enumerated, so the number of query
    calls grows exponentially with the dataset size.
    """
    n = len(dataset)
    totals = [0] * n

    for size in range(1, n + 1):
        # Shapley weight for a coalition of this size containing the player:
        # 1 / (C(n, size) * size) == (size - 1)! * (n - size)! / n!
        denominator = math.comb(n, size) * size

        for members in itertools.combinations(range(n), size):
            with_everyone = query_func([dataset[idx] for idx in members])

            for player in members:
                if size > 1:
                    remaining = [dataset[idx] for idx in members if idx != player]
                    without_player = query_func(remaining)
                else:
                    # Removing the only member leaves the empty coalition.
                    without_player = []

                if dataset[player] in with_everyone and dataset[player] not in without_player:
                    gain = 1
                else:
                    gain = 0

                totals[player] += gain / denominator

    return totals
# Example data: three records, each with 'age', 'value' and 'trust' fields
dataset = [
    dict(age=20, value=1, trust=10),
    dict(age=25, value=5, trust=50),
    dict(age=30, value=10, trust=100),
]

# Per-record Shapley values with respect to the 'age < 26' query
shapley_values_age = shapley_value(dataset, query_age)
print(f"Shapley values for query 'age < 26': {shapley_values_age}")

# Per-record Shapley values with respect to the 'trust = 100' query
shapley_values_trust = shapley_value(dataset, query_trust)
print(f"Shapley values for query 'trust = 100': {shapley_values_trust}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment