Skip to content

Instantly share code, notes, and snippets.

@botcs
Created February 7, 2024 21:59
Show Gist options
  • Save botcs/8157060321e39baaf9862f5242f28e44 to your computer and use it in GitHub Desktop.
Save botcs/8157060321e39baaf9862f5242f28e44 to your computer and use it in GitHub Desktop.
class FinalScoreAnnotator:
    """Plot curves and label each one with its final value.

    1) ``plot`` draws a curve and highlights its last point with a marker.
    2) ``annotate`` writes the last value of every curve to the right of the
       data and shifts the labels vertically (L-BFGS-B) so they do not
       overlap.

    Usage:
        annotator = FinalScoreAnnotator(ax)
        annotator.plot(x, y, label="run A")  # like ax.plot for one series
        annotator.annotate(fontsize=20)
    """

    def __init__(self, ax=None):
        # Fall back to the current axes, as pyplot functions do.
        self.ax = plt.gca() if ax is None else ax
        self.lines = []    # Line2D of each curve, in plotting order
        self.last_ys = []  # final y value of each curve
        self.last_xs = []  # final x value of each curve
        self.colors = []   # color of each curve, reused for its label
        # End-point markers, assigned round-robin by curve index. All entries
        # are distinct so every curve's end point stays identifiable.
        self.markers = ["*", "d", "X", "P", "^", "v", "<"]

    def plot(self, x, y, label=None, color=None, marker=None, **kwargs):
        """Plot one curve and mark its final point.

        Parameters:
        - x, y: data of the curve (anything ``Axes.plot`` accepts that can
          also be converted with ``np.asarray``).
        - label: legend label; attached to the end-point marker rather than
          the line, so the legend shows the marker.
        - color: line/marker color; defaults to the ``C<n>`` color-cycle
          entry for the n-th plotted curve.
        - marker: end-point marker; defaults to ``self.markers``, cycled.
        - **kwargs: forwarded to both ``Axes.plot`` calls. For the end-point
          marker they override the default styling (size, edge, alpha).
        Returns:
        - the Line2D of the curve.
        """
        line_idx = len(self.lines)
        if color is None:
            color = f"C{line_idx}"
        if marker is None:
            # Cycle instead of indexing past the end of the marker list.
            marker = self.markers[line_idx % len(self.markers)]
        # Positional "last element" for lists, arrays and pandas objects alike.
        x_last = np.asarray(x)[-1]
        y_last = np.asarray(y)[-1]

        line, = self.ax.plot(x, y, color=color, **kwargs)
        # Merge so caller kwargs such as ``alpha`` override the defaults
        # instead of raising "got multiple values for keyword argument".
        marker_style = {
            "markersize": 10,
            "markeredgecolor": "black",
            "markeredgewidth": 0.5,
            "alpha": 0.8,
            **kwargs,
        }
        self.ax.plot(
            x_last, y_last,
            color=color,
            marker=marker,
            label=label,
            **marker_style,
        )

        self.lines.append(line)
        self.last_ys.append(y_last)
        self.last_xs.append(x_last)
        self.colors.append(color)
        return line

    def get_fontsize_in_data_coords(self, ax, fontsize):
        """Measure a reference label ("99.9") in data coordinates of ``ax``.

        A temporary text artist is drawn, measured and removed again.

        Parameters:
        - ax: Matplotlib Axes object to measure in.
        - fontsize: font size of the reference label.
        Returns:
        - (text_width, text_height): size of the label in data units of ``ax``.
        """
        temp_text = ax.text(0, 0, "99.9", fontsize=fontsize)
        canvas = ax.figure.canvas
        # Render *this* figure synchronously so layout and the text extent
        # are up to date; plt.draw() targets the current figure and may only
        # schedule an idle redraw.
        canvas.draw()
        bbox = temp_text.get_window_extent()
        bbox_data = bbox.transformed(ax.transData.inverted())
        temp_text.remove()
        # Refresh the canvas so the temporary label does not linger on screen.
        canvas.draw_idle()
        # max - min (not width/height) stays positive on inverted axes.
        text_width = bbox_data.xmax - bbox_data.xmin
        text_height = bbox_data.ymax - bbox_data.ymin
        return text_width, text_height

    @staticmethod
    def _overlap_cost(positions, original_positions, heights, penalty=1e6):
        """Layout objective: squared displacement plus an overlap penalty.

        Labels centred at p_i, p_j with heights h_i, h_j overlap by
        (h_i + h_j) / 2 - |p_i - p_j| when that is positive; each overlapping
        pair adds ``penalty * overlap**2``. All arrays are 1-D floats.
        """
        cost = np.sum((positions - original_positions) ** 2)
        overlap = (heights[:, None] + heights[None, :]) / 2 - np.abs(
            positions[:, None] - positions[None, :]
        )
        # Count every unordered pair exactly once (strict upper triangle).
        pairs = np.triu_indices(len(positions), k=1)
        cost += penalty * np.sum(np.clip(overlap[pairs], 0.0, None) ** 2)
        return cost

    @classmethod
    def _resolve_overlaps(cls, original_positions, heights):
        """Return label centres close to ``original_positions`` without overlaps.

        Parameters:
        - original_positions: desired centre of each label.
        - heights: height of each label, in the same units.
        Returns:
        - list of adjusted centres, one per label.
        Raises:
        - ValueError: if the two sequences differ in length.
        """
        positions = np.asarray(original_positions, dtype=float)
        heights = np.asarray(heights, dtype=float)
        if positions.shape != heights.shape:
            raise ValueError(
                "original_positions and heights must have the same length."
            )
        if positions.size == 0:
            return []
        # |p_i - p_j| has a kink at 0, so identical positions give the solver
        # no direction in which to separate them. Spread the starting point
        # slightly (scaled to the label size; later curves start higher).
        tie_break = 1e-3 * heights.mean() * np.arange(positions.size)
        result = minimize(
            cls._overlap_cost,
            positions + tie_break,
            args=(positions, heights),
            method="L-BFGS-B",
        )
        return result.x.tolist()

    def annotate(self, fontsize=20, **kwargs):
        """Write each curve's final value to the right of the data.

        Labels share one x position just right of the largest final x and
        start at their curve's final y. They are then shifted vertically so
        no two overlap. Axis limits are widened so every label anchor lies
        inside the axes; annotations anchored outside are not drawn.

        Parameters:
        - fontsize: font size of the labels. Labels are measured at 1.2x this
          size so neighbouring labels keep some spacing.
        - **kwargs: forwarded to ``Axes.annotate``; they override the defaults
          set here (``color``, ``va``, ``ha``).
        Returns:
        - list of the created Annotation artists (empty if nothing has been
          plotted yet).

        NOTE(review): label sizes are measured in data units before the limits
        are widened, so after widening labels can end up slightly taller than
        the spacing reserved for them. The measurement also assumes linear
        axes — TODO confirm behaviour on log scales.
        """
        if not self.last_ys:
            return []

        text_width, text_height = self.get_fontsize_in_data_coords(
            self.ax, fontsize * 1.2
        )
        label_ys = self._resolve_overlaps(
            self.last_ys, [text_height] * len(self.last_ys)
        )
        pos_x = max(self.last_xs)
        label_x = pos_x + text_width * 0.2

        # LaTeX drops leading spaces, so under usetex the right-alignment
        # padding is written as explicit horizontal space. Without usetex the
        # command would be rendered literally, so plain spaces are kept.
        use_tex = kwargs.get("usetex", plt.rcParams["text.usetex"])
        annotations = []
        for last_y, label_y, color in zip(self.last_ys, label_ys, self.colors):
            text = f"{last_y:>5.1f}"
            if use_tex:
                text = text.replace(" ", r"\hspace{1cm}")
            options = {
                "fontsize": fontsize,
                "color": color,
                "va": "center",
                "ha": "left",
                **kwargs,
            }
            annotations.append(
                self.ax.annotate(text, xy=(label_x, label_y), **options)
            )

        # Widen the limits of *this* axes (not pyplot's current axes) so all
        # labels fit, including labels pushed below the original bottom.
        margin = text_height * 1.2
        xmin, xmax = self.ax.get_xlim()
        self.ax.set_xlim(xmin, max(xmax, pos_x + text_width * 1.2))
        ymin, ymax = self.ax.get_ylim()
        self.ax.set_ylim(
            min(ymin, min(label_ys) - margin),
            max(ymax, max(label_ys) + margin),
        )
        return annotations
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment