Skip to content

Instantly share code, notes, and snippets.

@flutefreak7
Created April 30, 2019 06:03
Show Gist options
  • Save flutefreak7/50ffd291eaa348ead35c9794587006df to your computer and use it in GitHub Desktop.
Save flutefreak7/50ffd291eaa348ead35c9794587006df to your computer and use it in GitHub Desktop.
Augment pandas DataFrame.to_latex() with vertical bars, multicolumn midrules, first-column alignment options, and header formatting.
import re
import numpy as np
def pandas_to_latex(df_table, latex_file, vertical_bars=False, right_align_first_column=True, header=True, index=False,
escape=False, multicolumn=False, **kwargs) -> None:
"""
Function that augments pandas DataFrame.to_latex() capability.
:param df_table: dataframe
:param latex_file: filename to write latex table code to
:param vertical_bars: Add vertical bars to the table (note that latex's booktabs table format that pandas uses is
incompatible with vertical bars, so the top/mid/bottom rules are changed to hlines.
:param right_align_first_column: Allows option to turn off right-aligned first column
:param header: Whether or not to display the header
:param index: Whether or not to display the index labels
:param escape: Whether or not to escape latex commands. Set to false to pass deliberate latex commands yourself
:param multicolumn: Enable better handling for multi-index column headers - adds midrules
:param kwargs: additional arguments to pass through to DataFrame.to_latex()
:return: None
"""
n = len(df_table.columns) + int(index)
if right_align_first_column:
cols = 'r' + 'c' * (n - 1)
else:
cols = 'c' * n
if vertical_bars:
# Add the vertical lines
cols = '|' + '|'.join(cols) + '|'
latex = df_table.to_latex(escape=escape, index=index, column_format=cols, header=header, multicolumn=multicolumn,
**kwargs)
if vertical_bars:
# Remove the booktabs rules since they are incompatible with vertical lines
latex = re.sub(r'\\(top|mid|bottom)rule', r'\\hline', latex)
# Multicolumn improvements - center level 1 headers and add midrules
if multicolumn:
latex = latex.replace(r'{l}', r'{c}')
offset = int(index)
midrule_str = ''
for i, col in enumerate(df_table.columns.levels[0]):
indices = np.nonzero(np.array(df_table.columns.codes[0]) == i)[0]
hstart = 1 + offset + indices[0]
hend = 1 + offset + indices[-1]
midrule_str += rf'\cmidrule(lr){{{hstart}-{hend}}} '
# Ensure that headers don't get colored by row highlighting
midrule_str += r'\rowcolor{white}'
latex_lines = latex.splitlines()
latex_lines.insert(3, midrule_str)
latex = '\n'.join(latex_lines)
with open(latex_file, 'w') as f:
f.write(latex)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment