Skip to content

Instantly share code, notes, and snippets.

@starwing
Last active August 13, 2019 14:41
Show Gist options
  • Save starwing/e4d8f3811f4eab3226a74ba11c6fc4c5 to your computer and use it in GitHub Desktop.
Save starwing/e4d8f3811f4eab3226a74ba11c6fc4c5 to your computer and use it in GitHub Desktop.
新写的红黑树……求轻喷……
#include <assert.h>
#include <stdlib.h>
#include <limits.h>
/* A red-black tree node.  `size` is the number of nodes in the subtree
 * rooted here (this node included); `color` is 1 for black, 0 for red. */
typedef struct rb_Node rb_Node;
struct rb_Node {
    rb_Node *parent;   /* NULL at the root; reused as the free-list link by put() */
    rb_Node *k[2];     /* children, indexed by LEFT / RIGHT */
    unsigned size  : CHAR_BIT*sizeof(unsigned) - 1;
    unsigned color : 1;
    int data;          /* the key */
};
#define PAGE_SIZE 4096
static rb_Node *pages, *freed;
/* Push node `n` onto the free list; its `parent` field becomes the link. */
static void put(rb_Node *n) {
    rb_Node *old_head = freed;
    freed = n;
    n->parent = old_head;
}
static rb_Node *get(void) {
rb_Node *n;
if (freed == NULL) {
size_t i;
void *newpage = malloc(PAGE_SIZE);
if (newpage == NULL) return NULL;
*(void**)((char*)newpage + PAGE_SIZE-sizeof(void*)) = pages;
pages = (rb_Node*)newpage;
for (i = 0; i < (PAGE_SIZE-sizeof(void*))/sizeof(rb_Node); ++i)
put(&pages[i]);
assert(freed != NULL);
}
n = freed;
freed = freed->parent;
return n;
}
/* Node helpers: rb_size treats a NULL subtree as empty; the color bit is
 * 1 for black and 0 for red. */
#define rb_size(n)     ((n) == NULL ? 0 : (n)->size)
#define rb_isblack(n)  ((n)->color == 1)
#define rb_setblack(n) ((n)->color = 1)
#define rb_clrblack(n) ((n)->color = 0)
/* Child indices into rb_Node.k[] */
enum { LEFT = 0, RIGHT = 1 };
/* Binary-search the subtree rooted at `n` for key `data`.
 * Returns the matching node, or NULL if the key is absent. */
static rb_Node *rb_find(rb_Node *n, int data) {
    for (; n != NULL; n = n->k[data > n->data]) {
        if (n->data == data)
            break;
    }
    return n;
}
/* In-order neighbour of `n` in the tree rooted at `r`.
 * right == RIGHT gives the successor, right == LEFT the predecessor
 * (`right` must be 0 or 1).  With n == NULL the walk starts: the minimum
 * (RIGHT) or maximum (LEFT) node is returned.  Returns NULL when the tree is
 * empty or `n` is already the last node in that direction. */
static rb_Node *rb_next(rb_Node *r, rb_Node *n, int right) {
    int left = !right;
    rb_Node *c;
    if (r == NULL) return NULL;
    c = (n == NULL) ? r : n->k[right];
    if (c != NULL) {
        /* Extreme node of that subtree on the opposite side */
        while (c->k[left]) c = c->k[left];
        return c;
    }
    /* No subtree on that side: climb while `n` is its parent's `right`-side
     * child.  The first ancestor reached from the other side IS the
     * neighbour (returning that ancestor's child would skip it). */
    while (n->parent && n == n->parent->k[right])
        n = n->parent;
    return n->parent;
}
/* Hang subtree `t` (may be NULL) where `n` currently hangs in the tree rooted
 * at `r`.  `n`'s own parent/child pointers are left untouched.
 * Returns the root, which is `t` when `n` was the root. */
static rb_Node *rb_replace(rb_Node *r, rb_Node *n, rb_Node *t) {
    rb_Node *p = n->parent;
    if (t != NULL)
        t->parent = p;
    if (p == NULL)
        return t;
    if (p->k[RIGHT] == n)
        p->k[RIGHT] = t;
    else
        p->k[LEFT] = t;
    return r;
}
/* Rotate around `n` towards side `right` (RIGHT: right rotation, LEFT: left
 * rotation).  The child on the opposite side, n->k[!right], takes n's place
 * and `n` becomes its `right`-side child; that child's inner subtree moves
 * across to `n`.  Subtree sizes are kept consistent.
 * Returns the (possibly new) root of the tree rooted at `r`. */
static rb_Node *rb_rotate(rb_Node *r, rb_Node *n, int right) {
    int up = !right;                 /* side of the child that rises */
    rb_Node *pivot = n->k[up];
    rb_Node *inner = pivot->k[right];

    n->k[up] = inner;
    if (inner != NULL)
        inner->parent = n;
    r = rb_replace(r, n, pivot);     /* reads n->parent, so before relinking n */
    pivot->k[right] = n;
    n->parent = pivot;
    pivot->size = n->size;           /* pivot now spans n's old subtree */
    n->size = rb_size(n->k[LEFT]) + rb_size(n->k[RIGHT]) + 1;
    return r;
}
/* Restore the red-black properties after the red node `n` has been linked
 * into the tree rooted at `r`; returns the (possibly new) root.
 * While `n` has a red parent `p`:
 *   - red uncle: color p and the uncle black, the grandparent red, and
 *     continue from the grandparent;
 *   - black or missing uncle: if `n` and `p` hang on different sides, first
 *     rotate `p` so they line up, then recolor and rotate the grandparent,
 *     which ends the loop.
 * Finally the root is colored black. */
static rb_Node *rb_maintain_insert(rb_Node *r, rb_Node *n) {
rb_Node *p, *gp, *u;
while ((p = n->parent) && !rb_isblack(p)) {
/* Relies on the root being black: a red p is never the root, so gp is
 * non-NULL.  uright is the side on which the uncle hangs from gp. */
int right, uright = (gp = p->parent)->k[RIGHT] != p;
if ((u = gp->k[uright]) && !rb_isblack(u))
rb_setblack(p), rb_setblack(u), rb_clrblack(n = gp);
else {
/* right = side of n under p; if it differs from p's side under gp, rotate
 * p so that n and p line up.  Afterwards `right` is p's side under gp. */
if ((right = p->k[RIGHT] == n) != (gp->k[RIGHT] == p))
r = rb_rotate(r, n = p, right = !right),
p = n->parent, gp = p->parent;
/* Rotate gp away from p's side so p takes gp's place as a black node */
rb_setblack(p), rb_clrblack(gp);
r = rb_rotate(r, gp, !right);
}
}
rb_setblack(r);
return r;
}
/* Insert key `data` into the tree rooted at `r` and return the new root.
 * If the key is already present the tree is returned unchanged.  If no node
 * can be allocated (get() returns NULL) the tree is also returned unchanged,
 * so a caller that must detect out-of-memory can check rb_find() afterwards. */
static rb_Node *rb_insert(rb_Node *r, int data) {
    rb_Node *n, **pn = &r, *parent = NULL;
    /* Descend to the empty link where `data` belongs, remembering its parent */
    while (*pn != NULL && (*pn)->data != data) {
        parent = *pn;
        pn = &parent->k[data > parent->data];
    }
    if (*pn != NULL)
        return r;                    /* key already present */
    n = get();
    if (n == NULL)
        return r;                    /* out of memory: leave the tree intact */
    n->parent = parent;
    n->k[LEFT] = n->k[RIGHT] = NULL;
    n->data = data;
    rb_clrblack(n);                  /* new nodes start red */
    n->size = 1;
    *pn = n;                         /* when the tree was empty, pn == &r, so r = n */
    /* Every ancestor gains one node in its subtree */
    for (; parent != NULL; parent = parent->parent)
        ++parent->size;
    return rb_maintain_insert(r, n);
}
/* Delete fix-up for the tree rooted at `r`, applied to `n` while the node
 * being removed is still linked (see rb_remove).  While `n` is black and not
 * the root, its subtree is one black node short; with p = n's parent and
 * w = n's sibling the classic cases apply:
 *   1. w red: recolor and rotate p so that n gets a black sibling;
 *   2. both of w's children black/missing: color w red and move the deficit
 *      up to p;
 *   3. w's far child black/missing (so the near one is red): rotate w so the
 *      far child becomes red;
 *   4. w's far child red: recolor, rotate p, and stop (n is set to the root).
 * Finally `n` is colored black.  Returns the (possibly new) root. */
static rb_Node *rb_maintain_delete(rb_Node *r, rb_Node *n) {
rb_Node *p, *w, *wl, *wr;
while ((p = n->parent) && rb_isblack(n)) {
/* right: side of n under p.  The code relies on a black non-root n always
 * having a non-NULL sibling. */
int right = p->k[RIGHT] == n;
if (!rb_isblack(w = p->k[!right])) {
/* Case 1 */
rb_setblack(w), rb_clrblack(p);
r = rb_rotate(r, p, right);
w = p->k[!right];
}
/* wl: w's child on n's side (near); wr: w's child on the far side */
wl = w->k[right], wr = w->k[!right];
if ((!wl || rb_isblack(wl)) && (!wr || rb_isblack(wr)))
/* Case 2 */
rb_clrblack(w), n = p;
else {
if (!wr || rb_isblack(wr)) {
/* Case 3 */
rb_clrblack(w), rb_setblack(wl);
r = rb_rotate(r, w, !right);
w = p->k[!right], wr = w->k[!right];
}
/* Case 4; assigning the root to n terminates the loop */
w->color = p->color, rb_setblack(p);
if (wr) rb_setblack(wr);
n = r = rb_rotate(r, p, right);
}
}
rb_setblack(n);
return r;
}
/* Remove node `n` (which must belong to the tree rooted at `r`) and return
 * the new root.  When `n` has two children its in-order predecessor `y` is
 * unlinked instead and then moved into n's position, taking over n's
 * children, size and color.  The node that leaves the tree is handed back to
 * the pool with put(). */
static rb_Node *rb_remove(rb_Node *r, rb_Node *n) {
rb_Node *x, *y = n, *c;
/* y: the node that is physically unlinked (n itself, or n's predecessor) */
if (n->k[LEFT] && n->k[RIGHT]) {
y = n->k[LEFT];
while (y->k[RIGHT]) y = y->k[RIGHT];
}
/* y has at most one child here; x is that child, or NULL */
x = y->k[y->k[RIGHT] != NULL];
/* Rebalance while y is still linked; when y is a leaf it stands in for the
 * empty subtree that will replace it */
r = rb_maintain_delete(r, x ? x : y);
r = rb_replace(r, y, x);
/* rb_replace leaves y->parent intact, so walk up from y to fix the counts */
for (c = y; c; c = c->parent)
--c->size;
if (y != n) {
/* Move y into n's slot, inheriting n's links, size and color */
r = rb_replace(r, n, y);
if ((c = y->k[LEFT] = n->k[LEFT])) c->parent = y;
if ((c = y->k[RIGHT] = n->k[RIGHT])) c->parent = y;
y->size = n->size, y->color = n->color;
y = n;
}
put(y);
return r;
}
#include <stdio.h>
#include <time.h>
/* Benchmark / self-test harness */
#define N 1000000   /* number of keys inserted and then deleted by main() */
/* Harness state, given internal linkage so it cannot clash with other
 * translation units: status[] holds print_tree's per-level connector codes,
 * data[] the keys, idx[] the shuffled deletion order. */
static int status[N], data[N], idx[N];
/* State of the harness's deterministic generator (same seed => same run) */
static unsigned seed = 1;
/* Linear congruential generator: seed = seed*1103515245 + 12345, wrapping
 * modulo UINT_MAX+1.  Returns the low 31 bits, so the result is a
 * non-negative int. */
static int my_rand(void) {
    seed = seed * 1103515245u + 12345u;
    return (int)(seed & 0x7FFFFFFFu);
}
/* Recursively assert the invariants of the subtree rooted at `n`: sizes add
 * up, children point back to their parent, a red node has no red child, and
 * both subtrees have the same black height.
 * Returns the black height of `n`, counting an empty subtree as 1. */
static int rb_check_recursive(rb_Node *n) {
    int black_left, black_right;
    if (!n)
        return 1;               /* empty subtrees count as black */
    black_left = rb_check_recursive(n->k[LEFT]);
    black_right = rb_check_recursive(n->k[RIGHT]);
    assert(rb_size(n) == rb_size(n->k[LEFT]) + rb_size(n->k[RIGHT]) + 1);
    assert(!n->k[LEFT] || n->k[LEFT]->parent == n);
    assert(!n->k[RIGHT] || n->k[RIGHT]->parent == n);
    if (rb_isblack(n) == 0) {   /* red node: every child must be black */
        assert(!n->k[LEFT] || rb_isblack(n->k[LEFT]));
        assert(!n->k[RIGHT] || rb_isblack(n->k[RIGHT]));
    }
    if (black_left != black_right)
        assert(!"black node not equal");
    return black_left + (rb_isblack(n) ? 1 : 0);
}
/* Assert that the tree rooted at `r` is a valid red-black tree: the root is
 * black and has no parent, rb_check_recursive()'s invariants hold, and keys
 * strictly increase along the rb_next(..., RIGHT) walk.  An empty tree passes. */
static void rb_check(rb_Node *r) {
    rb_Node *n;
    int prev;   /* not named `data`, which would shadow the file-scope array */
    if (r == NULL) return;
    assert(r->parent == NULL);
    assert(rb_isblack(r));
    rb_check_recursive(r);
    n = rb_next(r, NULL, RIGHT);
    prev = n->data;
    while ((n = rb_next(r, n, RIGHT)) != NULL) {
        assert(prev < n->data);
        prev = n->data;
    }
}
/* Print the subtree rooted at `n` as an ASCII tree, one node per line, e.g.
 *   5 (black)
 *   |-- 3 (red)
 *   `-- 8 (red)
 * `level` is n's depth (start with 0).  status[0..level-1] holds one
 * 4-column connector code per ancestor column: 0 blank, 1 "|" (more
 * siblings follow), 2 "|-- " (n is not the last child), 3 "`-- " (n is the
 * last child).  All four codes print the same width so nested levels line
 * up.  A node with at least one child prints both slots, using "(nil)" for a
 * missing child. */
static void print_tree(rb_Node *n, int level) {
    int i;
    for (i = 0; i < level; ++i) {
        switch (status[i]) {
        case 0: printf("    "); break;
        case 1: printf("|   "); break;
        case 2: printf("|-- "); break;
        case 3: printf("`-- "); break;
        default: break;
        }
    }
    if (n) printf("%d (%s)\n", n->data, rb_isblack(n) ? "black" : "red");
    else printf("(nil)\n");
    if (!n || (!n->k[LEFT] && !n->k[RIGHT]))
        return;
    /* Below this node, our own connector column turns into a bar (siblings
     * still follow) or into blank space (we were the last child). */
    if (level >= 1) status[level-1] = status[level-1] == 3 ? 0 : 1;
    status[level] = 2;
    print_tree(n->k[LEFT], level+1);
    status[level] = 3;
    print_tree(n->k[RIGHT], level+1);
}
/* Benchmark and self-test: insert N pseudo-random keys, check the tree,
 * delete the keys in a shuffled order, and check the tree ends up empty.
 * Each phase's CPU time (from clock()) is printed in milliseconds. */
int main(void) {
    int i;
    clock_t start;
    double elapsed = 0, total = 0;
    rb_Node *n = NULL, *t;

    /* Keys, plus the identity permutation in idx[] */
    for (i = 0; i < N; ++i) {
        idx[i] = i;
        data[i] = my_rand();
    }
    /* Fisher-Yates shuffle: idx[] becomes the deletion order */
    for (i = 0; i < N; ++i) {
        int pick = i + my_rand() % (N - i);
        int saved = idx[i];
        idx[i] = idx[pick];
        idx[pick] = saved;
    }

    printf("insert:\n");
    start = clock();
    for (i = 0; i < N; ++i) {
        n = rb_insert(n, data[i]);
        assert(rb_find(n, data[i]) != NULL);
        /* rb_check(n) here would validate after every insert (slow) */
    }
    elapsed = (double)(clock() - start)/CLOCKS_PER_SEC*1000;
    printf("insert: %.3fms\n", elapsed);
    total += elapsed;

    /* Re-inserting an existing key must leave the tree unchanged: both lines
     * should print the same node address. */
    printf("data[0]: %p\n", (void*)rb_find(n, data[0]));
    n = rb_insert(n, data[0]);
    printf("data[0]: %p\n", (void*)rb_find(n, data[0]));
    rb_check(n);

    printf("delete:\n");
    start = clock();
    for (i = 0; i < N; ++i) {
        t = rb_find(n, data[idx[i]]);
        assert(t);
        n = rb_remove(n, t);
        /* rb_check(n) here would validate after every delete (slow) */
    }
    elapsed = (double)(clock() - start)/CLOCKS_PER_SEC*1000;
    printf("delete: %.3fms\n", elapsed);
    total += elapsed;

    assert(n == NULL);
    print_tree(n, 0);
    rb_check(n);
    printf("total time: %.3fms\n", total);
    return 0;
}
/* cc: flags+='-ggdb -O2 -Wall -Wextra -std=c90 -pedantic --coverage' */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment