Skip to content

Instantly share code, notes, and snippets.

@X547
Created August 27, 2024 22:02
Show Gist options
  • Save X547/b3eb6a0e89f015932a367b075ee429ac to your computer and use it in GitHub Desktop.
Save X547/b3eb6a0e89f015932a367b075ee429ac to your computer and use it in GitHub Desktop.
RadixTree
#include "RadixTreeMap.h"
#include <new>
#include <algorithm>
#include <AutoDeleter.h>
#include "BitUtils.h"
// Evaluates `err` once and returns it from the enclosing function when it is
// an error code (< B_OK). Wrapped in do { } while (0) so that
// `CHECK_RET(x);` is a single statement and stays correct inside an unbraced
// if/else.
#define CHECK_RET(err) \
	do { \
		status_t _err = (err); \
		if (_err < B_OK) \
			return _err; \
	} while (0)
// Releases every node of the tree. No value callback is supplied, so the
// stored values themselves are not passed anywhere.
RadixTreeMapBase::~RadixTreeMapBase()
{
	MakeEmpty(NULL, NULL);
}
// Counts the keys that hold a value by stepping through them with
// NextUsed().
uint32
RadixTreeMapBase::CountItems() const
{
	// TODO: optimize (one NextUsed() lookup per stored key).
	uint32 total = 0;
	for (uint32 cursor = 0; NextUsed(cursor); cursor++) {
		total++;
		// Incrementing past the largest key would wrap around to 0.
		if (cursor == UINT32_MAX)
			break;
	}
	return total;
}
// True when the map holds no values. _Remove() deletes nodes as soon as they
// become empty, so a missing root is equivalent to an empty map.
bool
RadixTreeMapBase::IsEmpty() const
{
	return !fRoot;
}
// Frees every node of the tree. When freeValue is not NULL it is invoked as
// freeValue(cookie, value) for each stored value before its leaf is freed.
void
RadixTreeMapBase::MakeEmpty(void (*freeValue)(void* cookie, void* value), void* cookie)
{
	_MakeEmpty(fRoot, fDepth, freeValue, cookie);
	// _MakeEmpty() has cleared fRoot. Reset the depth as well, as
	// _ShrinkDepth() does for an empty tree. Otherwise the next Put() would
	// build as many levels as the old tree had, not as many as its key needs.
	fDepth = 0;
}
// Returns the value stored under `key`, or NULL if there is none.
void*
RadixTreeMapBase::Get(uint32 key) const
{
	// The tree only addresses the low kNodeBits * (fDepth + 1) key bits. A
	// key with any higher bit set cannot be stored. Without this check the
	// walk below would ignore those bits and return the value of an
	// unrelated, truncated key.
	uint32 coveredBits = kNodeBits * (fDepth + 1);
	if (coveredBits < kKeyBits && (key >> coveredBits) != 0)
		return NULL;

	Node* node = fRoot;
	for (uint32 level = fDepth; level > 0; level--) {
		if (node == NULL)
			return NULL;
		uint32 idx = GetBits(key, kNodeBits * level, kNodeBits);
		node = static_cast<IntNode*>(node)->nodes[idx];
	}
	if (node == NULL)
		return NULL;
	uint32 idx = GetBits(key, 0, kNodeBits);
	return static_cast<LeafNode*>(node)->values[idx];
}
// Stores `value` under `key`; a NULL `value` removes the key instead.
// `oldValue` receives the value previously stored under `key`. It is NULL if
// there was none, and also NULL on failure, since a failing insert never
// involves a key that was already present.
// Returns B_OK, or B_NO_MEMORY if a tree node could not be allocated (the
// tree is then left as it was).
status_t
RadixTreeMapBase::Put(uint32 key, void* value, void*& oldValue)
{
	if (value == NULL) {
		// Keys wider than the current tree cannot be present. Passing them
		// to _Remove() would truncate the key and clear an unrelated slot.
		uint32 coveredBits = kNodeBits * (fDepth + 1);
		if (coveredBits < kKeyBits && (key >> coveredBits) != 0) {
			oldValue = NULL;
			return B_OK;
		}
		_Remove(fRoot, fDepth, key, oldValue);
		_ShrinkDepth();
		return B_OK;
	}

	// Add root levels until every significant bit of the key is addressable.
	status_t res = _GrowDepth(key == 0 ? 0 : (std::bit_width(key) - 1) / kNodeBits);
	if (res < B_OK) {
		// Growth was needed, so the key was outside the tree: no old value.
		oldValue = NULL;
		return res;
	}

	res = _Insert(fRoot, fDepth, key, value, oldValue);
	if (res < B_OK) {
		// _Insert() only fails while allocating a missing node, so the key
		// was absent. Roll back the partially built path; _Remove() also
		// reports NULL as the old value.
		_Remove(fRoot, fDepth, key, oldValue);
		_ShrinkDepth();
	}
	return res;
}
// Moves `key` to the smallest key >= `key` that holds no value and returns
// true. Returns false (leaving `key` unchanged) only when every key from
// `key` up to UINT32_MAX is in use.
bool
RadixTreeMapBase::NextFree(uint32& key) const
{
	uint32 coveredBits = kNodeBits * (fDepth + 1);
	bool coversAllKeys = coveredBits >= kKeyBits;

	// Keys above the tree's range are never used, so `key` itself is free.
	// _Next() would truncate such a key and could report a key that is not
	// the smallest, or even one below the starting key.
	if (!coversAllKeys && (key >> coveredBits) != 0)
		return true;

	// Search on a copy: _Next() may rewrite the key before failing.
	uint32 probe = key;
	if (_Next(fRoot, fDepth, probe, false, _NextFreeIdx, _NextFreeIdx)) {
		key = probe;
		return true;
	}

	// Everything from `key` to the end of the covered range is used. The
	// first key past that range is free, unless the tree already spans all
	// 32 bits (in that case the shift below would also be undefined).
	if (coversAllKeys)
		return false;
	key = 1U << coveredBits;
	return true;
}
// Moves `key` to the smallest key >= `key` that holds a value and returns
// true. Returns false (leaving `key` unchanged) when there is no such key.
bool
RadixTreeMapBase::NextUsed(uint32& key) const
{
	// Keys above the tree's range can never be used. _Next() would truncate
	// such a key and report an unrelated stored key, which also made
	// CountItems() loop over the whole 32-bit key space.
	uint32 coveredBits = kNodeBits * (fDepth + 1);
	if (coveredBits < kKeyBits && (key >> coveredBits) != 0)
		return false;

	// Search on a copy: _Next() may rewrite the key before failing.
	uint32 probe = key;
	if (!_Next(fRoot, fDepth, probe, true, _NextIntUsedIdx, _NextLeafUsedIdx))
		return false;
	key = probe;
	return true;
}
// Stores `value` under `key` in the subtree rooted at `node`, which has
// `level` interior levels above its leaves, and reports the replaced value
// through `oldValue`.
// Missing nodes are allocated and linked into the tree immediately, so after
// a B_NO_MEMORY failure they remain for the caller to roll back (Put() does
// this with _Remove()). On the way back up, a parent's `full` bit is set once
// its child becomes completely occupied.
status_t
RadixTreeMapBase::_Insert(Node*& node, uint32 level, uint32 key, void* value, void*& oldValue)
{
	if (level == 0) {
		LeafNode* leaf = static_cast<LeafNode*>(node);
		if (leaf == NULL) {
			leaf = new(std::nothrow) LeafNode();
			if (leaf == NULL)
				return B_NO_MEMORY;
			node = leaf;
		}
		uint32 slot = GetBits(key, 0, kNodeBits);
		oldValue = leaf->values[slot];
		leaf->values[slot] = value;
		SetBit(leaf->full, slot);
		return B_OK;
	}

	IntNode* inner = static_cast<IntNode*>(node);
	if (inner == NULL) {
		inner = new(std::nothrow) IntNode();
		if (inner == NULL)
			return B_NO_MEMORY;
		node = inner;
	}
	uint32 slot = GetBits(key, kNodeBits * level, kNodeBits);
	Node*& child = inner->nodes[slot];
	status_t status = _Insert(child, level - 1, key, value, oldValue);
	if (status < B_OK)
		return status;
	if (child->full == GetMask<uint32>(kNodeCnt))
		SetBit(inner->full, slot);
	return status;
}
// True when none of the interior node's child slots is populated.
bool
RadixTreeMapBase::_IsEmpty(IntNode* node)
{
	return std::all_of(node->nodes, node->nodes + kNodeCnt,
		[](const Node* child) { return child == NULL; });
}
// Clears `key` from the subtree rooted at `node` (with `level` interior
// levels above its leaves) and reports the removed value through `oldValue`
// (NULL if the key was absent). Nodes left empty are deleted and their slot
// in the parent is reset to NULL.
void
RadixTreeMapBase::_Remove(Node*& node, uint32 level, uint32 key, void*& oldValue)
{
	// No subtree here means the key is not stored.
	if (node == NULL) {
		oldValue = NULL;
		return;
	}

	if (level == 0) {
		LeafNode* leaf = static_cast<LeafNode*>(node);
		uint32 slot = GetBits(key, 0, kNodeBits);
		oldValue = leaf->values[slot];
		leaf->values[slot] = NULL;
		ClearBit(leaf->full, slot);
		if (leaf->full != 0)
			return;
		delete leaf;
		node = NULL;
		return;
	}

	IntNode* inner = static_cast<IntNode*>(node);
	uint32 slot = GetBits(key, kNodeBits * level, kNodeBits);
	_Remove(inner->nodes[slot], level - 1, key, oldValue);
	// The child can no longer be completely occupied.
	ClearBit(inner->full, slot);
	if (!_IsEmpty(inner))
		return;
	delete inner;
	node = NULL;
}
// Recursively deletes the subtree rooted at `node` (with `level` interior
// levels above its leaves) and sets `node` to NULL. Each stored value is
// handed to freeValue(cookie, value) first, if freeValue is not NULL.
void
RadixTreeMapBase::_MakeEmpty(Node*& node, uint32 level, void (*freeValue)(void* cookie, void* value), void* cookie)
{
	if (node == NULL)
		return;

	if (level == 0) {
		LeafNode* leaf = static_cast<LeafNode*>(node);
		// Visit only occupied slots, in ascending order.
		for (uint32 slot = _NextLeafUsedIdx(leaf, 0); slot < kNodeCnt;
				slot = _NextLeafUsedIdx(leaf, slot + 1)) {
			if (freeValue != NULL)
				freeValue(cookie, leaf->values[slot]);
			leaf->values[slot] = NULL;
		}
		delete leaf;
	} else {
		IntNode* inner = static_cast<IntNode*>(node);
		for (Node*& child : inner->nodes) {
			if (child != NULL)
				_MakeEmpty(child, level - 1, freeValue, cookie);
		}
		delete inner;
	}
	node = NULL;
}
// Makes the tree at least `depth` interior levels tall by repeatedly putting
// a new root above the current one, with the old root in slot 0. If an
// allocation fails, the levels added so far are trimmed again by
// _ShrinkDepth() and B_NO_MEMORY is returned.
status_t
RadixTreeMapBase::_GrowDepth(uint32 depth)
{
	// Without nodes there is nothing to re-parent; only record the depth.
	if (fRoot == NULL) {
		if (depth > fDepth)
			fDepth = depth;
		return B_OK;
	}

	for (; fDepth < depth; fDepth++) {
		IntNode* newRoot = new(std::nothrow) IntNode();
		if (newRoot == NULL) {
			_ShrinkDepth();
			return B_NO_MEMORY;
		}
		bool oldRootFull = fRoot->full == GetMask<uint32>(kNodeCnt);
		newRoot->nodes[0] = fRoot;
		if (oldRootFull)
			SetBit(newRoot->full, 0);
		fRoot = newRoot;
	}
	return B_OK;
}
// True when every child slot except slot 0 is empty, so the node contributes
// only a leading zero digit to the keys beneath it.
bool
RadixTreeMapBase::_CanShrink(IntNode* node)
{
	Node* const* rest = node->nodes + 1;
	return std::none_of(rest, node->nodes + kNodeCnt,
		[](const Node* child) { return child != NULL; });
}
void
RadixTreeMapBase::_ShrinkDepth()
{
if (fRoot == NULL) {
fDepth = 0;
return;
}
while (fDepth > 0) {
IntNode* intNode = static_cast<IntNode*>(fRoot);
if (!_CanShrink(intNode))
return;
fRoot = intNode->nodes[0];
delete intNode;
fDepth--;
}
}
// Returns the first slot index >= `idx` whose `full` bit is clear, or a
// value >= kNodeCnt when every such slot is full.
uint32
RadixTreeMapBase::_NextFreeIdx(Node* node, uint32 idx)
{
	// Pretend all slots below `idx` are full; the run of trailing ones then
	// stops at the first non-full slot at or after `idx`.
	uint32 taken = GetMask<uint32>(idx) | node->full;
	return std::countr_one(taken);
}
// Returns the first child index >= `idx` of the interior node that holds a
// subtree. Returns `idx` itself if it is already >= kNodeCnt, and kNodeCnt if
// no populated child follows.
uint32
RadixTreeMapBase::_NextIntUsedIdx(Node* node, uint32 idx)
{
	const IntNode* inner = static_cast<const IntNode*>(node);
	for (; idx < kNodeCnt; idx++) {
		if (inner->nodes[idx] != NULL)
			break;
	}
	return idx;
}
// Returns the first slot index >= `idx` whose `full` bit is set (an occupied
// leaf slot), or a value >= kNodeCnt when there is none.
uint32
RadixTreeMapBase::_NextLeafUsedIdx(Node* node, uint32 idx)
{
	// Invert occupancy so empty slots read as ones, and mark the slots below
	// `idx` as ones too. The trailing-ones count then lands on the first
	// occupied slot at or after `idx`.
	uint32 skip = ~node->full | GetMask<uint32>(idx);
	return std::countr_one(skip);
}
// Searches the subtree rooted at `node` for the first key >= `key` that
// matches the search, and stores it in `key`.
//
// node        - subtree root with `level` interior levels above its leaves
//               (level 0 means `node` itself is a LeafNode).
// key         - in: starting key; out: the key found. Only the low
//               kNodeBits * (level + 1) bits are inspected, so higher bits
//               are effectively ignored. May be modified even when false is
//               returned.
// isUsed      - true to look for an occupied key, false for a free one; it
//               decides the meaning of a missing (NULL) subtree.
// NextIntIdx  - returns the first candidate child index >= idx of an
//               interior node (>= kNodeCnt if none).
// NextLeafIdx - returns the first candidate slot index >= idx of a leaf
//               (>= kNodeCnt if none).
//
// Returns true if a matching key was found.
bool
RadixTreeMapBase::_Next(Node* node, uint32 level, uint32& key,
bool isUsed,
uint32 (*NextIntIdx)(Node* node, uint32 idx),
uint32 (*NextLeafIdx)(Node* node, uint32 idx))
{
// An absent subtree has no used keys, so every key in it is free: a used-key
// search fails, while a free-key search succeeds with `key` as it is.
if (node == NULL)
return !isUsed;
if (level > 0) {
IntNode* intNode = static_cast<IntNode*>(node);
// Digit of `key` selecting the child at this level.
uint32 idx = GetBits(key, kNodeBits * level, kNodeBits);
// Candidate child index (named "free" but used by both search kinds).
uint32 freeIdx = idx;
for (;;) {
freeIdx = NextIntIdx(node, freeIdx);
if (freeIdx >= kNodeCnt)
return false;
// Moving to a later child: reposition `key` at that child's first key
// (this level's digit = freeIdx, all lower bits zero). NOTE(review):
// assumes SetBitsTo(v, offset, count, bits) replaces that bit field, as
// defined in BitUtils.h.
if (freeIdx > idx) {
key = SetBitsTo<uint32>(key, kNodeBits * level, kNodeBits, freeIdx);
key = SetBitsTo<uint32>(key, 0, kNodeBits * level, 0);
}
if (_Next(intNode->nodes[freeIdx], level - 1, key, isUsed, NextIntIdx, NextLeafIdx))
return true;
// No match in this child at or after `key`; try the next candidate.
freeIdx++;
if (freeIdx >= kNodeCnt)
return false;
}
}
// Leaf: pick the first candidate slot at or after the key's lowest digit.
uint32 idx = GetBits(key, 0, kNodeBits);
uint32 freeIdx = NextLeafIdx(node, idx);
if (freeIdx >= kNodeCnt)
return false;
key = SetBitsTo(key, 0, kNodeBits, freeIdx);
return true;
}
#pragma once
#include <private/shared/AutoDeleter.h>
// Untyped map from uint32 keys to void* values, stored as a radix tree.
// Each level consumes kNodeBits bits of the key (kNodeCnt slots per node).
// Storing NULL removes a key. Root levels are added when a key needs them
// and dropped again once they become redundant after removals.
//
// The map owns its nodes but not the stored values; pass a callback to
// MakeEmpty() to release them. It is not copyable, because a copy would
// share the node tree and free it twice.
class RadixTreeMapBase {
public:
	RadixTreeMapBase() = default;
	RadixTreeMapBase(const RadixTreeMapBase&) = delete;
	RadixTreeMapBase& operator=(const RadixTreeMapBase&) = delete;
	~RadixTreeMapBase();

	// Number of keys that hold a value.
	uint32 CountItems() const;
	bool IsEmpty() const;
	// Frees every node. If freeValue is not NULL it is first called as
	// freeValue(cookie, value) for each stored value.
	void MakeEmpty(void (*freeValue)(void* cookie, void* value) = NULL, void* cookie = NULL);
	// Value stored under key, or NULL.
	void* Get(uint32 key) const;
	// Stores value under key (NULL removes the key). On success oldValue
	// receives the replaced value (NULL if the key had none). Returns
	// B_NO_MEMORY if a node cannot be allocated.
	status_t Put(uint32 key, void* value, void*& oldValue);
	inline status_t Put(uint32 key, void* value)
		{void* oldValue; return Put(key, value, oldValue);}
	// Move key forward to the smallest key >= key that is free / in use.
	// Return false if there is no such key.
	bool NextFree(uint32& key) const;
	bool NextUsed(uint32& key) const;

private:
	static constexpr uint32 kKeyBits = 32;
	static constexpr uint32 kNodeBits = 4;
	static constexpr uint32 kNodeCnt = 1 << kNodeBits;

	// Common node header. Bit i of `full` is set when slot i of a leaf holds
	// a value, or when child i of an interior node is completely occupied.
	// Nodes have no virtual destructor; they are always deleted through
	// their concrete type.
	struct Node {
		uint32 full {};
	};
	struct IntNode: public Node {
		Node* nodes[kNodeCnt] {};
	};
	struct LeafNode: public Node {
		void* values[kNodeCnt] {};
	};

private:
	static status_t _Insert(Node*& node, uint32 level, uint32 key, void* value, void*& oldValue);
	static bool _IsEmpty(IntNode* node);
	static void _Remove(Node*& node, uint32 level, uint32 key, void*& oldValue);
	static void _MakeEmpty(Node*& node, uint32 level, void (*freeValue)(void* cookie, void* value), void* cookie);
	status_t _GrowDepth(uint32 depth);
	static bool _CanShrink(IntNode* node);
	void _ShrinkDepth();
	static uint32 _NextFreeIdx(Node* node, uint32 idx);
	static uint32 _NextIntUsedIdx(Node* node, uint32 idx);
	static uint32 _NextLeafUsedIdx(Node* node, uint32 idx);
	static bool _Next(Node* node, uint32 level, uint32& key,
		bool isUsed,
		uint32 (*NextIntIdx)(Node* node, uint32 idx),
		uint32 (*NextLeafIdx)(Node* node, uint32 idx));

private:
	// Number of interior levels above the leaves (0: the root is a leaf).
	uint32 fDepth = 0;
	Node* fRoot {};
};
// Typed, owning wrapper around RadixTreeMapBase that maps uint32 keys to
// Type*. Values are released with Deleter when they are overwritten by the
// two-argument Put(), removed with Remove(), or when the map is emptied or
// destroyed.
template <typename Type, typename Deleter = BPrivate::ObjectDelete<Type>>
class RadixTreeMap {
public:
	RadixTreeMap() = default;
	// Copying would share the node tree and delete the values twice.
	RadixTreeMap(const RadixTreeMap&) = delete;
	RadixTreeMap& operator=(const RadixTreeMap&) = delete;

	~RadixTreeMap()
	{
		MakeEmpty();
	}

	uint32 CountItems() const
	{
		return fBase.CountItems();
	}

	bool IsEmpty() const
	{
		return fBase.IsEmpty();
	}

	// Deletes every stored value with the deleter and frees all nodes.
	void MakeEmpty()
	{
		// The callback must be a plain function pointer, so the deleter
		// instance travels through the cookie. Without a cookie the callback
		// would dereference NULL.
		fBase.MakeEmpty([](void* cookie, void* value) {
			(*static_cast<Deleter*>(cookie))(static_cast<Type*>(value));
		}, &fDeleter);
	}

	Type* Get(uint32 key) const
	{
		return static_cast<Type*>(fBase.Get(key));
	}

	// Stores value under key without deleting anything. The replaced value
	// (NULL if there was none or the call failed) is returned in oldValue,
	// and the caller becomes responsible for it.
	status_t Put(uint32 key, Type* value, Type*& oldValue)
	{
		// Use a real void* instead of reinterpreting Type*& as void*&,
		// which would violate strict aliasing.
		void* previous = NULL;
		status_t res = fBase.Put(key, value, previous);
		oldValue = static_cast<Type*>(previous);
		return res;
	}

	// Stores value under key and deletes the value it replaces (unless it
	// is the same object).
	status_t Put(uint32 key, Type* value)
	{
		Type* oldValue;
		status_t res = Put(key, value, oldValue);
		if (res < B_OK)
			return res;
		if (oldValue != NULL && oldValue != value)
			fDeleter(oldValue);
		return B_OK;
	}

	// Removes key and deletes its value, if any.
	void Remove(uint32 key)
	{
		Type* oldValue;
		Put(key, NULL, oldValue);
		if (oldValue != NULL)
			fDeleter(oldValue);
	}

	bool NextFree(uint32& key) const
	{
		return fBase.NextFree(key);
	}

	bool NextUsed(uint32& key) const
	{
		return fBase.NextUsed(key);
	}

private:
	RadixTreeMapBase fBase;
	Deleter fDeleter;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment