Skip to content

Instantly share code, notes, and snippets.

@nalvared
Created October 31, 2017 17:41
Show Gist options
  • Save nalvared/abdce13fff49486ab13c7a2cdd894069 to your computer and use it in GitHub Desktop.
Save nalvared/abdce13fff49486ab13c7a2cdd894069 to your computer and use it in GitHub Desktop.
This is a minimal utility for producing a "prettier" layout of seaborn's pairplot function. It removes self-comparisons (a variable plotted against itself) and splits the charts into rows containing at most a specified number of items.

PlotHelper

Import library

The plothelper.py file must be in the same folder as the notebook or script that imports it.

import numpy as np
import pandas as pd
from plothelper import PlotHelper as PH
%matplotlib inline

Load dataset with pandas

# Read kc_house_data.csv into a pandas DataFrame for plotting below.
dfHouses = pd.read_csv(filepath_or_buffer='kc_house_data.csv')

Using PH.pairplot

This is the same function as seaborn.pairplot(), but it has two new parameters:

  • max_per_row: maximum number of charts per row
  • reg_line_color: when the chart type is kind='reg', this parameter paints the regression line in the chosen color
# Plot every column against 'price', five charts per row, with red regression lines.
plot_helper = PH()
plot_helper.pairplot(
    data=dfHouses,
    x_vars=dfHouses.columns.tolist(),
    y_vars=['price'],
    max_per_row=5,
    kind='reg',
    reg_line_color='red',
)
import numpy as np
import seaborn as sns
class PlotHelper:
    """Thin wrapper around ``seaborn.pairplot`` that can split wide grids into rows.

    Self-comparisons (a variable plotted against itself) are dropped, and the
    remaining x variables are laid out in chunks of at most ``max_per_row``
    charts, with one ``seaborn.pairplot`` figure drawn per chunk.
    """

    @staticmethod
    def _as_list(names):
        """Normalise a variable-name argument (None, one name, or an iterable) to a list."""
        if names is None:
            return []
        if isinstance(names, str):
            # A bare string is a single column name, not a sequence of characters.
            return [names]
        return list(names)

    @staticmethod
    def _merge_plot_kws(plot_kws, kind, reg_line_color):
        """Return ``plot_kws`` with the regression-line colour applied.

        The caller's dictionaries are copied, never mutated, and any other keys
        they contain (including other ``line_kws`` entries) are preserved.
        """
        if kind != 'reg' or reg_line_color is None:
            return plot_kws
        merged = dict(plot_kws or {})
        line_kws = dict(merged.get('line_kws') or {})
        line_kws['color'] = reg_line_color
        merged['line_kws'] = line_kws
        return merged

    def splitVariables(self, x_vars=None, y_vars=None, max_per_row=10):
        """Split ``x_vars`` into rows of at most ``max_per_row`` names for each y variable.

        Args:
            x_vars: Column names for the x axes (list, other iterable, or a single name).
            y_vars: Column names for the y axes (list, other iterable, or a single name).
            max_per_row: Maximum number of x variables per row; must be >= 1.

        Returns:
            A list of ``[x_names, y_name]`` pairs, in y order and then row order.
            ``x_names`` is always a non-empty list that never contains ``y_name``.
            A y variable with no other x variable produces no pair.

        Raises:
            ValueError: If ``max_per_row`` is None or smaller than 1.
        """
        if max_per_row is None or int(max_per_row) < 1:
            raise ValueError(
                "max_per_row must be a positive integer, got %r" % (max_per_row,))
        step = int(max_per_row)
        all_x = self._as_list(x_vars)
        slices = []
        for y in self._as_list(y_vars):
            # Remove the self comparison *before* chunking so rows stay full and
            # no chunk can end up empty (an empty x list breaks seaborn).
            others = [x for x in all_x if x != y]
            for start in range(0, len(others), step):
                slices.append([others[start:start + step], y])
        return slices

    def pairplot(self, data=None, hue=None, hue_order=None, palette=None,
                 vars=None, x_vars=None, y_vars=None, kind='scatter',
                 diag_kind='hist', markers=None, size=2.5, aspect=1,
                 dropna=True, plot_kws=None, diag_kws=None,
                 grid_kws=None, wrap=True, max_per_row=None, reg_line_color=None):
        """Draw ``seaborn.pairplot``, optionally split into rows of ``max_per_row`` charts.

        All seaborn arguments are forwarded unchanged. The one exception: when
        ``kind == 'reg'`` and ``reg_line_color`` is given, ``line_kws['color']``
        is set in a copy of ``plot_kws``, and the caller's other keys are kept.

        Args:
            wrap: Currently unused; kept for backward compatibility.
            max_per_row: If None, draw a single grid exactly like
                ``seaborn.pairplot``. Otherwise remove self comparisons, split
                the x variables into rows of at most this many charts, and draw
                one grid per row and y variable.
            reg_line_color: Colour of the regression line when ``kind='reg'``.

        Returns:
            The object returned by ``seaborn.pairplot`` when ``max_per_row`` is
            None; otherwise a list with one such object per row (empty when
            there is nothing to plot).
        """
        plot_kws = self._merge_plot_kws(plot_kws, kind, reg_line_color)
        # NOTE(review): later seaborn releases renamed ``size`` to ``height``;
        # confirm that the installed seaborn version still accepts ``size``.
        common = dict(data=data, hue=hue, hue_order=hue_order, palette=palette,
                      vars=vars, kind=kind, diag_kind=diag_kind, markers=markers,
                      size=size, aspect=aspect, dropna=dropna, plot_kws=plot_kws,
                      diag_kws=diag_kws, grid_kws=grid_kws)
        if max_per_row is None:
            return sns.pairplot(x_vars=x_vars, y_vars=y_vars, **common)

        if x_vars is None or y_vars is None:
            # splitVariables needs explicit names. Fall back to ``vars``, then to
            # the numeric columns, similar to seaborn's own defaults.
            if vars is not None:
                default_vars = vars
            else:
                default_vars = data.select_dtypes(include=[np.number]).columns.tolist()
            if x_vars is None:
                x_vars = default_vars
            if y_vars is None:
                y_vars = default_vars

        grids = []
        for row_x_vars, row_y_var in self.splitVariables(x_vars, y_vars, max_per_row):
            # Pass y as a one-element list so it is always treated as a single column.
            grids.append(sns.pairplot(x_vars=row_x_vars, y_vars=[row_y_var], **common))
        return grids
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment