# Scatter-plot matrix of the Iris dataset.
# Source: GitHub gist by RichardLitt (fb7cd6d8efebd8dc6280a3019111e610),
# created March 27, 2024 23:12.
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# Load the Iris dataset. Every attribute used below (data, target,
# feature_names, target_names) comes from the Bunch returned by
# scikit-learn's `datasets.load_iris()`; the previously referenced
# `load_penguins()` was never defined or imported.
iris = datasets.load_iris()
n_features = iris.data.shape[1]
n_classes = len(iris.target_names)

# One discrete colour per class, built once and shared by every scatter and
# the colour bar. `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was removed in
# Matplotlib 3.9; `plt.get_cmap(name, lut)` is the supported spelling.
# The norm centres each integer class label in its own colour band so the
# scatter colours line up with the colour-bar ticks.
cmap = plt.get_cmap('RdYlBu', n_classes)
norm = plt.Normalize(-0.5, n_classes - 0.5)

# Scatter-plot matrix: the cell at (row, col) plots feature `col` on the
# horizontal axis against feature `row` on the vertical axis.
fig, axarr = plt.subplots(n_features, n_features, figsize=(12, 12))
for row in range(n_features):
    for col in range(n_features):
        ax = axarr[row, col]
        if row == col:
            # Diagonal: show the feature name instead of a self-scatter.
            ax.text(0.5, 0.5, iris.feature_names[row],
                    horizontalalignment='center', verticalalignment='center')
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Off-diagonal: x = column feature, y = row feature, so the
            # edge labels set below describe what is actually plotted.
            ax.scatter(iris.data[:, col], iris.data[:, row],
                       c=iris.target, cmap=cmap, norm=norm)
        # Only the left column and bottom row carry labels/ticks.
        if col == 0:
            ax.set_ylabel(iris.feature_names[row])
        else:
            ax.set_yticks([])
        if row == n_features - 1:
            ax.set_xlabel(iris.feature_names[col])
        else:
            ax.set_xticks([])

# Colour bar to the right of the grid, labelled with the class names.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
# Tick positions are the integer class labels; round before indexing so a
# float like 0.9999999 does not map to the wrong class name.
formatter = plt.FuncFormatter(
    lambda value, pos: iris.target_names[int(round(value))])
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax,
             ticks=list(range(n_classes)), format=formatter)
plt.show()