Skip to content

Instantly share code, notes, and snippets.

@lsem
Last active August 5, 2024 20:43
Show Gist options
  • Save lsem/ba4e955941de8fc66115c2ba072d0a9d to your computer and use it in GitHub Desktop.
Save lsem/ba4e955941de8fc66115c2ba072d0a9d to your computer and use it in GitHub Desktop.
lfu_cache.cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
using namespace std;
/// Fixed-capacity least-frequently-used (LFU) cache from int keys to int
/// values. Among keys with the same use count the least recently used one is
/// evicted first.
///
/// Layout: all entries live in a single list (`m_usages`) ordered by
/// descending use count. Entries with equal use count form a contiguous
/// "group"; every member of a group shares one `range_node_t` that records the
/// group's first (`left_it`) and last (`right_it`) element. Within a group the
/// most recently used entry is leftmost. The eviction victim is therefore
/// always the last element of the list, and promoting an entry only needs the
/// shared range node to find where the neighbouring group starts.
class LFUCache {
 public:
  /// @param capacity Maximum number of keys kept. Zero or a negative value
  ///                 yields a cache that never stores anything.
  explicit LFUCache(int capacity)
      : m_capacity(capacity > 0 ? static_cast<std::size_t>(capacity) : 0) {}

  // The map and the range nodes hold iterators into m_usages. A member-wise
  // copy would leave the copy's iterators pointing into the source object's
  // list, so copying is disabled. Moving a std::list keeps element iterators
  // valid, so the defaulted moves are safe.
  LFUCache(const LFUCache&) = delete;
  LFUCache& operator=(const LFUCache&) = delete;
  LFUCache(LFUCache&&) = default;
  LFUCache& operator=(LFUCache&&) = default;

  /// Looks up `key` and counts the access as one use.
  ///
  /// The method is `const` only syntactically: it reorders the (mutable)
  /// usage list, so it must not be treated as a read-only operation.
  ///
  /// @return The stored value, or -1 when `key` is not present.
  int get(int key) const {
    const auto found = m_map.find(key);
    if (found == m_map.end()) {
      return -1;
    }
    return touch(found->second)->val;
  }

  /// Inserts `key` or overwrites its value.
  ///
  /// Overwriting an existing key counts as one use. Inserting a new key into
  /// a full cache first evicts the least frequently used entry (least
  /// recently used on ties). With zero capacity the call does nothing.
  void put(int key, int value) {
    if (m_capacity == 0) {
      // Nothing may be stored; without this guard the eviction below would
      // run on an empty list.
      return;
    }
    if (const auto found = m_map.find(key); found != m_map.end()) {
      found->second->val = value;
      touch(found->second);
      return;
    }
    if (m_map.size() >= m_capacity) {
      evict_one();
    }
    // New entries start with use count 0 at the very end of the list and are
    // then moved to the front of the count-0 group (if there is one).
    m_usages.push_back(keyval_t{key, value, 0, nullptr});
    const iterator added = std::prev(m_usages.end());
    onboard_new_item(added);
    m_map.emplace(key, added);
  }

 private:
  struct range_node_t;

  /// One cache entry as stored in the usage list.
  struct keyval_t {
    int key{}, val{}, freq{};
    /// Group descriptor shared by all entries with the same `freq`.
    std::shared_ptr<range_node_t> range_node;
  };

  using iterator = std::list<keyval_t>::iterator;

  /// Boundaries of one group of equally frequent entries. It is shared by all
  /// members so that each of them can find the start of the group in O(1).
  struct range_node_t {
    iterator left_it, right_it;
    int freq{};  // use count of the group; written but not read by the cache
  };

  /// Removes the last list element (lowest use count, least recently used)
  /// from both the list and the key index. Requires a non-empty cache.
  void evict_one() {
    assert(!m_usages.empty());
    const iterator victim = std::prev(m_usages.end());
    range_node_t& group = *victim->range_node;
    if (group.left_it != victim) {
      // Other members remain: the group now ends one element earlier.
      group.right_it = std::prev(victim);
    }
    // A single-member group needs no fix-up; its range node dies together
    // with the list element that owns the last reference.
    m_map.erase(victim->key);
    m_usages.pop_back();
  }

  /// Places `it` as the most recently used member of the group whose use
  /// count equals `it->freq`, creating that group when it does not exist.
  ///
  /// `boundary` is the first element of the region with a lower use count
  /// (it may be `it` itself); the target group, if present, ends directly
  /// before it.
  void attach_before(iterator boundary, iterator it) const {
    if (boundary != m_usages.begin() &&
        std::prev(boundary)->freq == it->freq) {
      // Join the existing group as its new leftmost member.
      const std::shared_ptr<range_node_t> group =
          std::prev(boundary)->range_node;
      m_usages.splice(group->left_it, m_usages, it);
      group->left_it = it;
      it->range_node = group;
    } else {
      // No group with this use count yet: start one right before `boundary`.
      // When boundary == it the splice is a no-op.
      auto group = std::make_shared<range_node_t>();
      group->left_it = it;
      group->right_it = it;
      group->freq = it->freq;
      m_usages.splice(boundary, m_usages, it);
      it->range_node = group;
    }
  }

  /// Assigns a freshly appended entry (use count 0, last list element) to
  /// the count-0 group.
  void onboard_new_item(iterator it) {
    assert(it == std::prev(m_usages.end()));
    assert(it->freq == 0);
    attach_before(it, it);
  }

  /// Records one use of `it`: increments its use count and moves it to the
  /// front of the group with the new count.
  ///
  /// @return `it`; list iterators stay valid across splice.
  iterator touch(iterator it) const {
    // Keep the old group alive until the end of the function even if `it`
    // was its only member.
    const std::shared_ptr<range_node_t> old_group = it->range_node;
    const iterator boundary = old_group->left_it;
    const bool was_leftmost = (it == old_group->left_it);
    const bool was_rightmost = (it == old_group->right_it);

    // Leave the old group. The neighbours have to be read before `it` is
    // spliced away. A sole member needs no update; a middle member does not
    // affect the boundaries.
    if (was_leftmost && !was_rightmost) {
      old_group->left_it = std::next(it);
    } else if (was_rightmost && !was_leftmost) {
      old_group->right_it = std::prev(it);
    }

    ++it->freq;
    attach_before(boundary, it);
    return it;
  }

  std::size_t m_capacity{};
  /// Key -> position of the entry in m_usages.
  std::unordered_map<int, iterator> m_map;
  /// Entries by descending use count; mutable because get() reorders it.
  mutable std::list<keyval_t> m_usages;
};
///////////////////////////////////////////////////////////////////////////////
// ANSI escape sequences used to colour the test report (NC resets the colour).
#define GREEN "\033[0;32m"
#define RED "\033[0;31m"
#define NC "\033[0m"
// expect_eq(Expression, Expected): evaluates both arguments exactly once and
// prints a coloured PASSED line when they compare equal, otherwise a FAILED
// line carrying the source location and the actual value. A failure is only
// reported on std::cout; execution continues.
#define expect_eq(Expression, Expected)                                       \
  do {                                                                        \
    auto lhs_value_ = (Expression);                                           \
    auto rhs_value_ = (Expected);                                             \
    if (lhs_value_ == rhs_value_) {                                           \
      std::cout << GREEN "PASSED" NC << ": " << #Expression                   \
                << " == " << #Expected << "\n";                               \
    } else {                                                                  \
      std::cout << RED "FAILED: " NC << __FILE__ << ":" << __LINE__ << ": "   \
                << #Expression << " != " << #Expected << ", Actual: \""       \
                << lhs_value_ << "\" \n";                                     \
    }                                                                         \
  } while (false)
///////////////////////////////////////////////////////////////////////////////
// Capacity-2 scenario: checks that the least frequently used key is evicted
// and that, when use counts are equal, the least recently used key goes first.
// Results are reported through expect_eq on std::cout.
void test_case1() {
LFUCache c{2};
c.put(1, 1);
c.put(2, 2);
// Key 1 has now been used once; key 2 has never been read.
expect_eq(c.get(1), 1);
// The cache is full, so inserting key 3 must evict one key; key 2 is expected
// to be the victim.
c.put(3, 3);
expect_eq(c.get(2), -1);
expect_eq(c.get(3), 3);
// Keys 1 and 3 have each been read once; key 1 was read longer ago and is
// expected to be evicted by this insert.
c.put(4, 4);
expect_eq(c.get(1), -1);
expect_eq(c.get(3), 3);
expect_eq(c.get(4), 4);
}
// Capacity-3 scenario with no reads before the overflow: all keys have the
// same use count, so the oldest insert (key 1) is expected to be evicted.
void test_case2() {
LFUCache c{3};
c.put(1, 1);
c.put(2, 2);
c.put(3, 3);
// Fourth insert into a cache of capacity 3 forces one eviction.
c.put(4, 4);
expect_eq(c.get(4), 4);
expect_eq(c.get(3), 3);
expect_eq(c.get(2), 2);
expect_eq(c.get(1), -1);
}
// Extends test_case2's sequence with a second overflow: after keys 4, 3 and 2
// have each been read once (in that order), inserting key 5 is expected to
// evict key 4, the least recently read of the three.
void test_case3() {
LFUCache c{3};
c.put(1, 1);
c.put(2, 2);
c.put(3, 3);
// Fourth insert into a cache of capacity 3 forces one eviction (key 1).
c.put(4, 4);
expect_eq(c.get(4), 4);
expect_eq(c.get(3), 3);
expect_eq(c.get(2), 2);
expect_eq(c.get(1), -1);
// Second overflow; the inline notes list the keys expected to remain.
c.put(5, 5); // expected: 3, 2, 5
expect_eq(c.get(1), -1);
expect_eq(c.get(2), 2); // expected: 2, 3, 5
expect_eq(c.get(3), 3); // expected: 3, 2, 5
expect_eq(c.get(4), -1);
expect_eq(c.get(5), 5);
}
void test_case4() {
LFUCache c{10};
c.put(10, 13);
c.put(3, 17);
c.put(6, 11);
c.put(10, 5);
c.put(9, 10);
c.get(13);
c.put(2, 19);
c.get(2);
c.get(3);
c.put(5, 25);
c.get(8);
c.put(9, 22);
c.put(5, 5);
c.put(1, 30);
c.get(11);
c.put(9, 12);
c.get(7);
c.get(5);
c.get(8);
c.get(9);
c.put(4, 30);
c.put(9, 3);
c.get(9);
c.get(10);
c.get(10);
c.put(6, 14);
c.put(3, 1);
c.get(3);
c.put(10, 11);
c.get(8);
c.put(2, 14);
c.get(1);
c.get(5);
c.get(4);
c.put(11, 4);
c.put(12, 24);
c.put(5, 18);
c.get(13);
c.put(7, 23);
c.get(8);
c.get(12);
c.put(3, 27);
c.put(2, 12);
c.get(5);
c.put(2, 9);
c.put(13, 4);
c.put(8, 18);
c.put(1, 7);
c.get(6);
c.put(9, 29);
c.put(8, 21);
c.get(5);
c.put(6, 30);
c.put(1, 12);
c.get(10);
c.put(4, 15);
c.put(7, 22);
c.put(11, 26);
c.put(8, 17);
c.put(9, 29);
c.get(5);
c.put(3, 4);
c.put(11, 30);
c.get(12);
c.put(4, 29);
c.get(3);
c.get(9);
c.get(6);
c.put(3, 4);
c.get(1);
c.get(10);
c.put(3, 29);
c.put(10, 28);
c.put(1, 20);
c.put(11, 13);
c.get(3);
c.put(3, 12);
c.put(3, 8);
c.put(10, 9);
c.put(3, 26);
c.get(8);
c.get(7);
c.get(5);
c.put(13, 17);
c.put(2, 27);
c.put(11, 15);
c.get(12);
c.put(9, 19);
c.put(2, 15);
c.put(3, 16);
c.get(1);
c.put(12, 17);
c.put(9, 1);
c.put(6, 19);
c.get(4);
c.get(5);
c.get(5);
c.put(8, 1);
c.put(11, 7);
c.put(5, 2);
c.put(9, 28);
c.get(1);
c.put(2, 2);
c.put(7, 4);
c.put(4, 22);
c.put(7, 24);
c.put(9, 26);
c.put(13, 28);
c.put(11, 26);
}
// Entry point: runs every test case once, in declaration order.
int main() {
  using test_fn = void (*)();
  const test_fn suite[] = {test_case1, test_case2, test_case3, test_case4};
  for (const test_fn run_case : suite) {
    run_case();
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment