Skip to content

Instantly share code, notes, and snippets.

@cutthroat
Created May 21, 2010 15:41
Show Gist options
  • Save cutthroat/408990 to your computer and use it in GitHub Desktop.
Save cutthroat/408990 to your computer and use it in GitHub Desktop.
Simple AVL tree implementation
#include <stdlib.h>
#include <errno.h>
#include <stdio.h>
#include "avltree.h"
/*
 * One tree node. `diff` is the balance factor, height(right) - height(left);
 * it stays within -1..1 between operations and may briefly reach +/-2 while
 * an insert/remove is unwinding, until balance() repairs it.
 */
struct node {
struct node *left, *right; /* child subtrees; NULL when absent */
int diff; /* height(right) - height(left) */
key_t key; /* ordering key; each key is stored at most once */
};
/* Tree handle (opaque to users of avltree.h); root is NULL for an empty tree. */
struct avltree {
struct node *root;
};
/* Larger of two ints (b when they compare equal). */
static inline int max(int a, int b)
{
    if (b >= a)
        return b;
    return a;
}
/* Smaller of two ints (b when they compare equal). */
static inline int min(int a, int b)
{
    if (b <= a)
        return b;
    return a;
}
/*
 * Recompute balance factors after rotate_right() has lifted the left child
 * (*bp) above its former parent (*ap). Both inputs are the pre-rotation
 * factors. With diff = height(right) - height(left):
 *   a' = a + 1 - min(b, 0)
 *   b' = b + 1 + max(a', 0)
 */
static inline void fix_diffs_right(int *ap, int *bp)
{
    const int old_a = *ap;
    const int old_b = *bp;
    const int new_a = old_a + 1 - (old_b < 0 ? old_b : 0);
    const int new_b = old_b + 1 + (new_a > 0 ? new_a : 0);

    *ap = new_a;
    *bp = new_b;
}
/*
 * Recompute balance factors after rotate_left() has lifted the right child
 * (*bp) above its former parent (*ap). Both inputs are the pre-rotation
 * factors. With diff = height(right) - height(left):
 *   a' = a - 1 - max(b, 0)
 *   b' = b - 1 + min(a', 0)
 */
static inline void fix_diffs_left(int *ap, int *bp)
{
    const int old_a = *ap;
    const int old_b = *bp;
    const int new_a = old_a - 1 - (old_b > 0 ? old_b : 0);
    const int new_b = old_b - 1 + (new_a < 0 ? new_a : 0);

    *ap = new_a;
    *bp = new_b;
}
/*
 * Single right rotation of the subtree at *rp (whose left child must exist):
 *
 *        top              pivot
 *       /   \            /     \
 *    pivot   C   ==>    A      top
 *    /   \                    /   \
 *   A  inner               inner   C
 *
 * Balance factors of the two moved nodes are updated and *rp is re-pointed
 * at the new subtree root.
 */
static inline void rotate_right(struct node **rp)
{
    struct node *const top = *rp;
    struct node *const pivot = top->left;
    struct node *const inner = pivot->right;

    top->left = inner;
    pivot->right = top;
    fix_diffs_right(&top->diff, &pivot->diff);
    *rp = pivot;
}
/*
 * Single left rotation of the subtree at *rp (whose right child must exist):
 *
 *      top                  pivot
 *     /   \                /     \
 *    A   pivot    ==>    top      C
 *        /   \          /   \
 *     inner   C        A   inner
 *
 * Balance factors of the two moved nodes are updated and *rp is re-pointed
 * at the new subtree root.
 */
static inline void rotate_left(struct node **rp)
{
    struct node *const top = *rp;
    struct node *const pivot = top->right;
    struct node *const inner = pivot->left;

    top->right = inner;
    pivot->left = top;
    fix_diffs_left(&top->diff, &pivot->diff);
    *rp = pivot;
}
/*
 * Restore the AVL property at *rp when its balance factor has reached +/-2,
 * using a single or double rotation. Returns 1 if a rotation was performed,
 * 0 if the node was already within -1..1 (nothing is changed then).
 */
static inline int balance(struct node **rp)
{
    struct node *const top = *rp;

    switch (top->diff) {
    case 2:
        /* Right-left shape: straighten the right child first. */
        if (top->right->diff == -1)
            rotate_right(&top->right);
        rotate_left(rp);
        return 1;
    case -2:
        /* Left-right shape: straighten the left child first. */
        if (top->left->diff == 1)
            rotate_left(&top->left);
        rotate_right(rp);
        return 1;
    default:
        return 0;
    }
}
/*
 * Allocate a new leaf holding `key` and store it in *rp.
 * Always returns 1: the subtree at *rp grew from empty to one node.
 * On allocation failure a diagnostic is printed and the program exits with
 * EXIT_FAILURE.
 */
static int insert_leaf(key_t key, struct node **rp)
{
    struct node *a = malloc(sizeof *a);

    if (a == NULL) {
        perror(NULL);
        /* ISO C does not require malloc to set errno, so exit(errno) could
         * exit with status 0 and report success; fail explicitly instead. */
        exit(EXIT_FAILURE);
    }
    a->left = a->right = NULL;
    a->diff = 0;
    a->key = key;
    /* Link the node in only once it is fully initialised. */
    *rp = a;
    return 1;
}
/*
 * Insert `key` into the subtree at *rp, rebalancing on the way back up.
 * Returns 1 when the subtree's height grew, 0 otherwise. A key that is
 * already present leaves the tree unchanged and returns 0.
 */
static int insert(key_t key, struct node **rp)
{
    struct node *const a = *rp;
    int grew;

    if (a == NULL)
        return insert_leaf(key, rp);
    if (key == a->key)
        return 0;

    if (key < a->key) {
        grew = insert(key, &a->left);
        if (grew)
            a->diff -= 1;
        /* The height grows only if the node just tipped from balanced to left-heavy. */
        if (grew && a->diff == -1)
            return 1;
    } else {
        grew = insert(key, &a->right);
        if (grew)
            a->diff += 1;
        /* The height grows only if the node just tipped from balanced to right-heavy. */
        if (grew && a->diff == 1)
            return 1;
    }

    /* A factor of +/-2 is rotated away; after an insert that restores the old height. */
    if (a->diff != 0)
        balance(rp);
    return 0;
}
/*
 * Detach the leftmost (minimum) node of the subtree at *rp and store it in
 * *lp; that node's right subtree takes its place. Rebalances on the way up.
 * Returns 1 when the subtree's height shrank, 0 otherwise.
 * Precondition: *rp != NULL.
 */
static int unlink_left(struct node **rp, struct node **lp)
{
    struct node *const a = *rp;

    if (a->left != NULL) {
        if (unlink_left(&a->left, lp)) {
            a->diff += 1;
            /* Was left-heavy, now balanced: the subtree lost a level. */
            if (a->diff == 0)
                return 1;
        }
        if (a->diff == 0)
            return 0;
        /* +/-1: height unchanged (balance() does nothing and returns 0).
         * +2: rotate; the height shrinks only if the new root is balanced. */
        return balance(rp) && (*rp)->diff == 0;
    }

    *rp = a->right;
    *lp = a;
    return 1;
}
/*
 * Delete the node at *rp (which must be non-NULL) and free it.
 * Returns 1 when the subtree's height shrank, 0 otherwise.
 *
 * A node with at most one child is replaced by that child. Otherwise its
 * in-order successor (the leftmost node of the right subtree) is unlinked and
 * moved into its place, taking over its children and balance factor.
 */
static int remove_root(struct node **rp)
{
int delta;
struct node *a = *rp, *b;
if (a->left == NULL || a->right == NULL) {
/* Zero or one child: splice that child (or NULL) into the parent slot. */
*rp = a->right == NULL ? a->left : a->right;
free(a);
return 1;
}
/* unlink_left() writes the successor straight into *rp (the slot that held
 * `a`); `a` still points at the node being removed. delta is 1 when the
 * right subtree got shorter. */
delta = unlink_left(&a->right, rp);
b = *rp;
b->left = a->left;
b->right = a->right; /* reflects any relinking/rotation done by unlink_left() */
b->diff = a->diff;
free(a);
/* A shorter right subtree shifts the balance left; repair if it reached -2. */
if (delta && (--b->diff) == 0)
return 1;
if (b->diff != 0)
return balance(rp) && (*rp)->diff == 0;
return 0;
}
/*
 * Remove `key` from the subtree at *rp if it is present, rebalancing on the
 * way back up. Returns 1 when the subtree's height shrank, 0 otherwise
 * (including when the key is absent).
 */
static int remove_(key_t key, struct node **rp)
{
    struct node *const a = *rp;
    int shrank;

    if (a == NULL)
        return 0;
    if (key == a->key)
        return remove_root(rp);

    if (key < a->key) {
        shrank = remove_(key, &a->left);
        if (shrank)
            a->diff += 1;
    } else {
        shrank = remove_(key, &a->right);
        if (shrank)
            a->diff -= 1;
    }

    /* Balanced now: the subtree is shorter exactly when a child shrank. */
    if (a->diff == 0)
        return shrank;
    /* +/-1 keeps the height; +/-2 is rotated, shrinking only if the new root balances. */
    return balance(rp) && (*rp)->diff == 0;
}
/*
 * Free every node of the subtree rooted at `a` (NULL is allowed).
 * It recurses into left children and loops down the right spine.
 */
static void free_(struct node *a)
{
    while (a != NULL) {
        struct node *const next = a->right;

        free_(a->left);
        free(a);
        a = next;
    }
}
/* Add `key` to the tree; inserting a key that is already present does nothing. */
void avl_insert(key_t key, struct avltree *avl)
{
    struct node **const root = &avl->root;

    (void)insert(key, root);
}
/* Remove `key` from the tree; removing an absent key does nothing. */
void avl_remove(key_t key, struct avltree *avl)
{
    struct node **const root = &avl->root;

    (void)remove_(key, root);
}
/* Return 1 if `key` is stored in the tree, 0 otherwise. The tree is not modified. */
int avl_lookup(key_t key, struct avltree *avl)
{
    const struct node *cur;

    for (cur = avl->root; cur != NULL; ) {
        if (key < cur->key)
            cur = cur->left;
        else if (key > cur->key)
            cur = cur->right;
        else
            return 1;
    }
    return 0;
}
/*
 * Release every node and then the tree handle itself.
 * Like free(), a NULL argument is accepted and does nothing.
 */
void avl_free(struct avltree *avl)
{
    if (avl == NULL)
        return;
    free_(avl->root);
    free(avl);
}
/*
 * Allocate an empty tree; release it with avl_free().
 * Never returns NULL: on allocation failure a diagnostic is printed and the
 * program exits with EXIT_FAILURE.
 */
struct avltree *avl_make(void)
{
    struct avltree *avl = malloc(sizeof *avl);

    if (avl == NULL) {
        perror(NULL);
        /* ISO C does not require malloc to set errno, so exit(errno) could
         * exit with status 0 and report success; fail explicitly instead. */
        exit(EXIT_FAILURE);
    }
    avl->root = NULL;
    return avl;
}
#ifndef DS_AVLTREE_H
#define DS_AVLTREE_H
/*
 * Set of integer keys stored in an AVL (height-balanced binary search) tree.
 *
 * NOTE(review): POSIX <sys/types.h> also declares a type named key_t (for
 * System V IPC). C11 accepts a repeated typedef only if both name the same
 * type, so confirm this on every target platform (or rename in a future API).
 */
typedef int key_t;

/* Opaque tree handle. */
struct avltree;

/* Create an empty tree. Exits the program if memory cannot be allocated. */
struct avltree *avl_make(void);

/* Free the tree and all keys stored in it. */
void avl_free(struct avltree *);

/* Add a key; inserting a key that is already present does nothing. */
void avl_insert(key_t, struct avltree *);

/* Remove a key; removing an absent key does nothing. */
void avl_remove(key_t, struct avltree *);

/* Return 1 if the key is present, 0 otherwise. */
int avl_lookup(key_t, struct avltree *);
#endif /* DS_AVLTREE_H */
#include <stdlib.h>
#include <stdio.h>
#include "avltree.c"
/*
 * Test-failure helper: print the formatted message plus a newline to stdout,
 * then exit with status 1. `fmt` must be a string literal because it is
 * pasted next to "\n". `##__VA_ARGS__` (a GNU extension) drops the trailing
 * comma when no extra arguments are passed.
 */
#define fail(fmt, ...) \
do { \
printf(fmt "\n", ##__VA_ARGS__); \
exit(1); \
} while (0)
/*
 * Fill keys[0..n-1] with pseudo-random values in [0, n). The seed is fixed,
 * so every run exercises the same insert/remove sequence. Duplicates are
 * likely, which also exercises the "key already present/absent" paths.
 *
 * A local 32-bit linear congruential generator replaces rand_r(). rand_r() is
 * POSIX-only (obsolescent since POSIX.1-2008) and not part of ISO C, and its
 * output may differ between C libraries, which made the test data platform
 * dependent.
 */
void generate_keys(int n, key_t keys[])
{
    unsigned long state = 1234;
    int i;

    for (i = 0; i < n; ++i) {
        /* Unsigned arithmetic wraps (no UB). Masking to 32 bits keeps the
         * sequence identical whether unsigned long is 32 or 64 bits wide. */
        state = (state * 1664525UL + 1013904223UL) & 0xFFFFFFFFUL;
        /* An LCG's low bits are weak, so derive the key from the upper ones. */
        keys[i] = (key_t)((state >> 8) % (unsigned long)n);
    }
}
int valid(struct node *a)
{
int lh, rh, b;
if (a == NULL)
return 0;
lh = valid(a->left);
rh = valid(a->right);
b = rh - lh;
if (b != a->diff)
fail("not ok - balance %d must be %d", a->diff, b);
if (abs(b) > 1)
fail("not ok - balance %d must be less than 2", b);
return max(lh, rh) + 1;
}
/* Run the full invariant check over the whole tree; exits via fail() on error. */
void should_be_valid(struct avltree *avl)
{
    struct node *const root = avl->root;

    (void)valid(root);
}
/* Insert keys[0..n-1] in order, re-validating the whole tree after each insert. */
void insert_all(int n, key_t keys[], struct avltree *avl)
{
    int done = 0;

    while (done < n) {
        avl_insert(keys[done], avl);
        should_be_valid(avl);
        done++;
    }
}
/* Remove keys[0..n-1] in order, re-validating the whole tree after each removal. */
void remove_all(int n, key_t keys[], struct avltree *avl)
{
    int done = 0;

    while (done < n) {
        avl_remove(keys[done], avl);
        should_be_valid(avl);
        done++;
    }
}
/* Fail (and exit) on the first of keys[0..n-1] that is missing from the tree. */
void should_lookup(int n, key_t keys[], struct avltree *avl)
{
    int i = 0;

    while (i < n) {
        const key_t k = keys[i++];

        if (!avl_lookup(k, avl))
            fail("not ok - lookup %d", k);
    }
}
/* Fail (and exit) on the first of keys[0..n-1] that is still found in the tree. */
void should_not_lookup(int n, key_t keys[], struct avltree *avl)
{
    int i = 0;

    while (i < n) {
        const key_t k = keys[i++];

        if (avl_lookup(k, avl))
            fail("not ok - not lookup %d", k);
    }
}
#define N 2048

/*
 * Test driver:
 *   - insert N pseudo-random keys, checking the AVL invariants after each insert;
 *   - confirm every key can be found;
 *   - remove them all, re-checking the invariants after each removal;
 *   - confirm none remain, then free the tree.
 * Any failure prints a "not ok" line and exits with status 1.
 */
int main(void)
{
    key_t keys[N];
    struct avltree *avl = avl_make();

    generate_keys(N, keys);
    insert_all(N, keys, avl);
    should_lookup(N, keys, avl);
    remove_all(N, keys, avl);
    should_not_lookup(N, keys, avl);
    avl_free(avl);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment