Last active
March 22, 2024 01:22
-
-
Save Shiina18/be34ef49ff634f5a564e78a1335bcb62 to your computer and use it in GitHub Desktop.
Very verbose implementation for finding the top k values with the minimum number of comparisons, as described in Knuth's The Art of Computer Programming, Volume 3, page 212. See https://stackoverflow.com/questions/4956593/optimal-algorithm-for-returning-top-k-values-from-an-array-of-length-n
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Very verbose implementation for finding the top k (smallest) values of an array
with as few comparisons as possible. See
https://stackoverflow.com/questions/4956593/optimal-algorithm-for-returning-top-k-values-from-an-array-of-length-n
"""
from __future__ import annotations | |
import dataclasses | |
import enum | |
class Result(enum.IntEnum):
    """Which side of a ``Solution.duel`` won.

    Being an ``IntEnum``, the members also compare equal to the plain ints
    0 and 1.
    """

    LEFT = 0  # the left-hand node won the duel
    RIGHT = 1  # the right-hand node won the duel
@dataclasses.dataclass(eq=False)
class Node:
    """A node of the tournament tree built by ``Solution``.

    Leaves hold input values; an internal node holds a copy of the value of
    the child that won its duel. A node whose ``value`` is ``None`` is a bye:
    ``Solution.duel`` makes it lose automatically.

    ``eq=False`` gives identity-based ``==`` and ``hash``. The link fields
    form cycles (a child's ``next_winner`` points back at its parent), so the
    field-wise ``__eq__`` a plain dataclass generates would recurse without
    bound when two distinct linked nodes are compared. The links are also
    left out of ``repr`` so printing a node does not dump the whole tree.
    """

    value: 'float | None' = None
    # Children in the tournament tree (None for leaves and byes).
    left: 'Node | None' = dataclasses.field(default=None, repr=False)
    right: 'Node | None' = dataclasses.field(default=None, repr=False)
    # The parent node, i.e. the match this node plays next (None at the root).
    next_winner: 'Node | None' = dataclasses.field(default=None, repr=False)
    # Whichever of ``left``/``right`` won this node's duel (None for leaves).
    prev_winner: 'Node | None' = dataclasses.field(default=None, repr=False)
class Solution:
    """Select the k smallest values of a list with a knockout tournament tree.

    Every value comparison made by :meth:`duel` is counted in ``n_duels``, so
    an instance can be used to measure how many comparisons the selection
    needed. Duels against a ``None`` value (a bye, or an entry removed from
    contention) are decided without a comparison and are not counted.
    """

    def __init__(self):
        # Number of real value comparisons performed so far by this instance.
        self.n_duels = 0

    def duel(self, node_left: Node, node_right: Node) -> Result:
        """Return which of the two nodes holds the smaller value.

        A node whose value is ``None`` loses automatically and no comparison
        is counted (if both are ``None`` the left node is reported). Otherwise
        one comparison is counted; ties go to the right node.
        """
        if node_right.value is None:
            return Result.LEFT
        if node_left.value is None:
            return Result.RIGHT
        self.n_duels += 1
        if node_left.value < node_right.value:
            return Result.LEFT
        return Result.RIGHT

    def find_topk(self, arr: list[int], k: int = 3) -> list[int]:
        """Return the ``k`` smallest values of ``arr`` (not necessarily sorted).

        For ``k >= 2`` a tournament is played over ``arr[k-2:]`` and its winner
        is the first result. Each of the first ``k - 2`` values is then fed in
        as a replacement for the latest winner, producing one more result
        each. Finally the latest winner is removed without replacement to
        produce the last result. Duplicate values are supported. For
        ``k == 1`` the single tournament winner is returned.

        If ``len(arr) <= k``, a copy of ``arr`` is returned unchanged (no
        duels). ``arr`` itself is never modified.

        Raises:
            ValueError: if ``k < 1``.
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}")
        if len(arr) <= k:
            return list(arr)
        if k == 1:
            return [self.get_tourney_winner(arr).value]

        bench = arr[:k - 2]
        winner = self.get_tourney_winner(arr[k - 2:])
        topk = [winner.value]
        for value in bench:
            winner = self.sub(value, winner)
            topk.append(winner.value)
        # Remove the current winner without a replacement. ``None`` loses every
        # duel for free, so no comparison against the removed entry is
        # counted (this replaces an ``inf`` sentinel plus a manual
        # ``n_duels`` decrement).
        winner = self.sub(None, winner)
        topk.append(winner.value)
        # Internal self-check of the selected multiset (not input validation).
        assert sorted(topk) == sorted(arr)[:k]
        return topk

    def get_tourney_winner(self, arr) -> Node:
        """Play a knockout tournament over ``arr`` and return its root node.

        Adjacent nodes of each layer are paired; an odd node out is paired
        with an empty ``Node()`` bye. Each match creates a parent node linked
        to both children (``left``/``right``), to the child that won
        (``prev_winner``), and from each child back to the parent
        (``next_winner``). The root's ``value`` is the smallest value.

        Raises:
            ValueError: if ``arr`` is empty.
        """
        if not arr:
            raise ValueError("cannot play a tournament over an empty sequence")
        layer = [Node(value=x) for x in arr]
        while len(layer) > 1:
            next_layer = []
            for i in range(0, len(layer), 2):
                node_left = layer[i]
                # The last node of an odd-sized layer gets a bye.
                node_right = layer[i + 1] if i + 1 < len(layer) else Node()
                res = self.duel(node_left, node_right)
                prev_winner = node_left if res == Result.LEFT else node_right
                node = Node(
                    value=prev_winner.value,
                    left=node_left,
                    right=node_right,
                    prev_winner=prev_winner,
                )
                node_left.next_winner = node
                node_right.next_winner = node
                next_layer.append(node)
            layer = next_layer
        # Also correct for a one-element input, where no match is played.
        return layer[0]

    def sub(self, new_value, winner: Node) -> Node:
        """Replace the current winner's leaf value with ``new_value`` and replay.

        The leaf that produced ``winner`` is found by following ``prev_winner``
        links. Its value is overwritten with ``new_value`` (``None`` takes it
        out of contention). Then every duel on the path from that leaf back to
        the root is replayed in place, one duel per level.

        Returns:
            The root node, whose ``value`` is the new overall winner.
        """
        leaf = winner
        while leaf.prev_winner is not None:
            leaf = leaf.prev_winner
        leaf.value = new_value
        # Update the path nodes in place rather than rebuilding them. No
        # parent slot has to be re-linked, which the previous value-based
        # re-linking got wrong when duplicate values were present.
        node = leaf
        while node.next_winner is not None:
            node = node.next_winner
            res = self.duel(node.left, node.right)
            node.prev_winner = node.left if res == Result.LEFT else node.right
            node.value = node.prev_winner.value
        return node
def main(n_trials: int = 100000, n: int = 128, k: int = 8) -> None:
    """Benchmark ``Solution.find_topk`` and print statistics of ``n_duels``.

    Runs ``n_trials`` independent trials, each on a freshly shuffled
    ``list(range(n))``. Prints the max, min, mean, standard deviation and
    standard error of the number of counted duels per trial.

    Raises:
        ValueError: if ``n_trials < 1`` (the statistics would be undefined).
    """
    # Imported here so that importing this module (to reuse Solution) does
    # not require tqdm/numpy or run the benchmark.
    import random

    import numpy as np
    import tqdm

    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    ns = []
    for _ in tqdm.tqdm(range(n_trials)):
        arr = list(range(n))
        random.shuffle(arr)
        solution = Solution()
        solution.find_topk(
            arr=arr,
            k=k,
        )
        ns.append(solution.n_duels)
    print("Max:", max(ns), "Min:", min(ns))
    mean = np.mean(ns)
    print("Mean:", mean)
    std_dev = np.std(ns)
    print("Standard Deviation:", std_dev)
    std_err = std_dev / np.sqrt(len(ns))
    print("Standard Error:", std_err)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment