Created February 7, 2024 21:59
-
-
Save botcs/8157060321e39baaf9862f5242f28e44 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FinalScoreAnnotator:
    """Plot curves and label each one with its final value.

    1) Marks the last point of every curve and annotates it with its value.
    2) Spreads the annotations vertically to avoid overlaps by minimising a
       penalised least-squares objective with L-BFGS-B.

    Usage:
        ``plot(x, y, label=None, color=None, marker=None, **kwargs)`` draws one
        curve like ``ax.plot``; call ``annotate()`` once after all curves have
        been plotted.
    """

    def __init__(self, ax=None):
        """Bind to ``ax``, or to the current axes (``plt.gca()``) if omitted."""
        self.ax = ax if ax is not None else plt.gca()
        self.lines = []  # Line2D of each full curve, in plotting order
        self.last_ys = []  # final y value of each curve
        self.last_xs = []  # final x value of each curve
        self.colors = []  # colour used for each curve
        # Default end-point markers, cycled by curve index.
        # NOTE(review): "^" appears twice; the last entry may have been meant
        # to be a different marker. Kept as-is to preserve existing figures.
        self.markers = ["*", "d", "X", "P", "^", "v", "^"]

    @staticmethod
    def _last(values):
        """Return the final element of a list, array or pandas Series."""
        # Series.__getitem__ with an integer index is label-based, so ``[-1]``
        # would raise KeyError; use positional access when it is available.
        return values.iloc[-1] if hasattr(values, "iloc") else values[-1]

    def plot(self, x, y, label=None, color=None, marker=None, **kwargs):
        """Plot one curve and draw a marker on its final point.

        Args:
            x, y: coordinate sequences; their last elements are the point
                that gets the marker and, later, the annotation.
            label: legend label. It is attached to the end-point marker, not
                to the line.
            color: curve colour; defaults to ``"C{n}"`` for the n-th curve
                plotted through this annotator.
            marker: end-point marker; defaults to ``self.markers``, cycled.
            **kwargs: forwarded to both ``ax.plot`` calls. For the end-point
                call they override the default marker styling.

        Returns:
            The ``Line2D`` of the full curve.
        """
        line_idx = len(self.lines)
        if color is None:
            color = f"C{line_idx}"
        if marker is None:
            # Cycle so the 8th and later curves do not index past the list.
            marker = self.markers[line_idx % len(self.markers)]

        (line,) = self.ax.plot(x, y, color=color, **kwargs)

        last_x, last_y = self._last(x), self._last(y)
        # Defaults come first so caller kwargs (e.g. ``alpha``) override them
        # instead of raising "got multiple values for keyword argument".
        marker_style = {
            "markersize": 10,
            "markeredgecolor": "black",
            "markeredgewidth": 0.5,
            "alpha": 0.8,
            **kwargs,
        }
        self.ax.plot(
            last_x,
            last_y,
            color=color,
            marker=marker,
            label=label,
            **marker_style,
        )

        self.lines.append(line)
        self.last_ys.append(last_y)
        self.last_xs.append(last_x)
        self.colors.append(color)
        return line

    def get_fontsize_in_data_coords(self, ax, fontsize):
        """Measure a short numeric label ("99.9") in data coordinates.

        A temporary text artist is added to ``ax``, measured after a draw, and
        then removed again.

        Args:
            ax: Matplotlib Axes whose data transform is used.
            fontsize: font size of the probe text.

        Returns:
            ``(width, height)`` of the probe text in data units.
        """
        canvas = ax.figure.canvas
        probe = ax.text(0, 0, "99.9", fontsize=fontsize)
        # Draw ``ax``'s own figure; plt.draw() would only redraw the *current*
        # figure, which need not be the one that owns ``ax``.
        canvas.draw()
        bbox = probe.get_window_extent().transformed(ax.transData.inverted())
        probe.remove()
        canvas.draw_idle()  # refresh so the probe does not linger on screen
        # xmax-xmin / ymax-ymin stay positive even on inverted axes.
        return bbox.xmax - bbox.xmin, bbox.ymax - bbox.ymin

    @staticmethod
    def _spread_positions(original_positions, heights, penalty=1e6):
        """Return label centres close to ``original_positions`` that do not overlap.

        Minimises ``sum((p - original)**2) + penalty * sum(overlap_ij**2)``
        over all label pairs i < j, where
        ``overlap_ij = (h_i + h_j) / 2 - |p_i - p_j|`` counts only when positive.

        Args:
            original_positions: desired label centres.
            heights: label heights, one per position.
            penalty: weight of the squared-overlap term.

        Returns:
            A list of corrected positions, in the same order as the input.

        Raises:
            ValueError: if the two inputs differ in length.
        """
        # float dtype: integer y values would otherwise break the in-place
        # jitter arithmetic and the optimiser.
        original = np.asarray(original_positions, dtype=float)
        heights = np.asarray(heights, dtype=float)
        if original.shape != heights.shape:
            raise ValueError(
                "original_positions and heights must have the same length."
            )
        if original.size == 0:
            return []

        # Work in units of the tallest label. This divides the objective by a
        # constant, so the minimiser is unchanged, but it makes the penalty
        # weight and the optimiser's tolerances independent of the data scale.
        tallest = float(heights.max())
        scale = tallest if tallest > 0 else 1.0
        target = original / scale
        half = heights / scale / 2
        upper = np.triu_indices(target.size, k=1)
        min_gap = (half[:, None] + half[None, :])[upper]

        def objective(p):
            overlap = min_gap - np.abs(p[:, None] - p[None, :])[upper]
            displacement = np.sum((p - target) ** 2)
            return displacement + penalty * np.sum(np.clip(overlap, 0.0, None) ** 2)

        # Tiny, deterministic, index-ordered offset: curves that end at exactly
        # the same value start apart, so the optimiser can push them in
        # opposite directions. Unlike the global RNG, this gives reproducible
        # figures.
        eps = 1e-5
        start = target + np.linspace(-eps, eps, target.size)
        result = minimize(objective, start, method="L-BFGS-B")
        return (result.x * scale).tolist()

    def annotate(self, fontsize=20, **kwargs):
        """Annotate every plotted curve with its final y value.

        Labels are placed just right of the right-most end point. Their
        vertical positions are nudged apart so they do not overlap, and the
        axis limits of ``self.ax`` are widened so the labels fit.

        Args:
            fontsize: font size of the labels.
            **kwargs: forwarded to ``ax.annotate``; they override the default
                ``fontsize``/``color``/``va``/``ha`` styling.
        """
        if not self.lines:
            return  # nothing plotted yet; max() below would raise on []

        ax = self.ax
        # Measured at 1.2x the font size to leave some padding between labels.
        text_width, text_height = self.get_fontsize_in_data_coords(
            ax, fontsize * 1.2
        )
        corrected_positions = self._spread_positions(
            self.last_ys, [text_height] * len(self.last_ys)
        )

        pos_x = max(self.last_xs)
        label_x = pos_x + text_width * 0.2
        # LaTeX strips leading blanks, so under usetex the right-aligning
        # padding is made explicit. Without usetex, "\hspace" is not
        # interpreted and would be rendered literally.
        usetex = kwargs.get("usetex", plt.rcParams["text.usetex"])
        for last_y, pos_y, color in zip(
            self.last_ys, corrected_positions, self.colors
        ):
            text = f"{last_y:>5.1f}"
            if usetex:
                text = text.replace(" ", r"\hspace{1cm}")
            style = {
                "fontsize": fontsize,
                "color": color,
                "va": "center",
                "ha": "left",
                **kwargs,
            }
            ax.annotate(text, xy=(label_x, pos_y), **style)

        # Adjust the limits of self.ax (plt.xlim/ylim would target the current
        # axes). NOTE: text_width/text_height were measured before this change,
        # so they no longer match exactly once the limits are widened.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.set_xlim(xlim[0], max(xlim[1], pos_x + text_width * 1.2))
        # Extend the bottom too: an annotation whose data-coordinate xy falls
        # outside the axes is clipped (not drawn) by default.
        ax.set_ylim(
            min(ylim[0], min(corrected_positions) - text_height / 2),
            max(max(corrected_positions) + text_height * 1.2, ylim[1]),
        )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.