#!/usr/local/bin/cz --

use b

# hat: an open-addressed hash table. Entries live directly in a slot array
# and collisions are resolved by probing with a hash-derived step (hat_step).
#   size  - number of slots in the array p0 .. p1 (set by hat_init_prime)
#   count - number of occupied slots (hat_put increments on a new key,
#           hat_del1 decrements on removal)
#   p0    - first slot of the array
#   p1    - one past the last slot (p0 + size); used to wrap probes
#   hash  - key hash function
#   cmp   - key comparison; hat_match treats a 0 result as "equal"
struct hat:
	int size
	int count
	hat_entry *p0
	hat_entry *p1
	hash_fn *hash
	cmp_fn *cmp

# hat_entry: one slot of a hat table.
#   hash - cached hash of key; 0 marks an empty slot (hat_hash remaps a
#          real hash of 0 to 1 so it cannot be confused with "empty")
#   key  - the key pointer as passed to hat_put
#   val  - the value pointer as passed to hat_put
struct hat_entry:
	hash_t hash
	void *key
	void *val

# Initial slot count when no size is given; 101 is prime, as hat_init_prime
# requires.
def hat_default_size 101
# Load threshold in 256ths of the slot count: hat_full reports "full" once
# count+1 reaches size*128/256, i.e. about half the slots.
def hat_full_per_256 128

# The size-taking hat_init variants assume size < 2^32 (size is passed through prime_2pow_32).
# hat_init: initialise the hash table d.
# Omitted arguments default to cstr_hash / cstr_cmp for keys and
# hat_default_size slots. An explicit size is passed through
# prime_2pow_32 before use.
# Every variant ends in the 4-argument hat_init_prime.
def hat_init(d)
	hat_init_prime(d, cstr_hash, cstr_cmp, hat_default_size)
def hat_init(d, size)
	hat_init_prime(d, cstr_hash, cstr_cmp, prime_2pow_32(size))
def hat_init(d, hash, cmp)
	hat_init_prime(d, hash, cmp, hat_default_size)
def hat_init(d, hash, cmp, size)
	hat_init_prime(d, hash, cmp, prime_2pow_32(size))

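# Note: hat_init_prime uses size unchanged, so size MUST be prime: the probe step from hat_step (1 .. size-1) only reaches every slot when size is prime.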
# Two-argument form: C-string keys (cstr_hash / cstr_cmp).
def hat_init_prime(d, size)
	hat_init_prime(d, cstr_hash, cstr_cmp, size)
# Set up an empty table d with exactly `size` slots and the given key
# hash/compare functions.
# NOTE(review): hat_find treats hash == 0 as an empty slot, so this relies
# on Valloc returning zero-filled memory. Confirm against the b library.
hat_init_prime(hat *d, hash_fn *hash, cmp_fn *cmp, size_t size)
	d->hash = hash
	d->cmp = cmp
	d->count = 0
	d->size = size
	hat_entry *slots = Valloc(hat_entry, size)
	d->p0 = slots
	d->p1 = slots + size

# Hash `key` with t's hash function. Declares two names in the caller's scope:
#   h - the hash, with 0 remapped to 1 because 0 marks an empty slot
#   p - the key's home slot, t->p0[h % t->size]
# NOTE(review): assumes hash_t is unsigned. A negative h would give a
# negative slot index. Confirm.
def hat_hash(h, p, t, key):
	hash_t h = t->hash(key)
	if !h:
		h = 1
	hat_entry *p = &t->p0[h % t->size]

# Probe stride for hash h. The result is always in 1 .. t->size-1, so it is
# never 0, and with a prime size it visits every slot.
def hat_step(step, t, h):
	step = 1 + h % (t->size - 1)

# Advance slot pointer p by `step` slots, wrapping past the end of t's
# slot array. step must be in 1 .. t->size-1, as produced by hat_step.
# The wrap is decided before p moves, so p never leaves [p0, p1). In C,
# forming a pointer beyond one-past-the-end of the array (the old
# `p += step` then compare) is undefined behaviour.
def hat_next(p, t, step):
	if t->p1 - p > step:
		p += step
	else:
		p -= t->size - step

# True (non-zero) when slot p is occupied; an empty slot has hash 0.
def hat_p_full(p) p->hash

# Does slot p hold key k (with hash h)? The cached hash is checked first so
# t->cmp only runs on likely matches. cmp returning 0 means "equal".
def hat_match(t, p, h, k) p->hash == h && t->cmp(p->key, k) == 0

# Probe t for key k, starting at home slot p (as set up by hat_hash).
# On exit:
#   m true  - p points at the entry holding k
#   m false - p points at the first empty slot on k's probe path, which is
#             where hat_put stores a new key
# Each miss advances p by the hat_step stride via hat_next. Previously p
# never moved, so a collision looped forever.
# Termination relies on the table always having at least one empty slot
# (hat_put grows the table when hat_full is reached).
# Declares hat_find_step in the caller's scope.
def hat_find(m, t, p, h, k)
	size_t hat_find_step
	hat_step(hat_find_step, t, h)
	repeat:
		m = hat_match(t, p, h, k)
		if m:
			break
		if !p->hash:
			break
		hat_next(p, t, hat_find_step)

# Is table t at its load threshold? True once one more entry would reach
# hat_full_per_256/256 of the slots.
def hat_full(t) t->count + 1 >= t->size * hat_full_per_256 / 256

# FIXME this doesn't distinguish 'replace' vs 'new' in the return code.
# Insert or replace key -> val in table t.
# Returns the value previously stored under key, or NULL if key was new;
# a stored NULL value looks the same as "new".
# The table is grown via hat_double first when hat_full reports it is at the
# load threshold. The count, fullness check and growth all operate on the
# table t. The old code used the hash variable h, which is not yet declared
# at that point and is not a table.
# NOTE(review): hat_double is not defined in the visible source. Confirm it
# exists and rebuilds p0/p1/size.
void *hat_put(hat *t, void *key, void *val):
	bit m
	if hat_full(t):
		hat_double(t)
	hat_hash(h, p, t, key)
	hat_find(m, t, p, h, key)
	void *oldval = NULL
	if m:
		oldval = p->val
	else:
		++t->count
	p->hash = h
	p->key = key
	p->val = val
	return oldval

# FIXME this doesn't indicate 'not found' as an error condition
# Look up key in table t. Returns its value, or NULL when key is absent
# (indistinguishable from a stored NULL value).
void *hat_get(hat *t, void *key):
	bit found
	hat_hash(h, p, t, key)
	hat_find(found, t, p, h, key)
	if !found:
		return NULL
	return p->val

# Remove key from table t.
# Returns the value that was stored under key, or NULL if key was absent
# (a stored NULL value looks the same).
# The entry count lives on the table t. The old `--h->count` dereferenced
# the hash value.
# FIXME clearing the slot (hash = 0) can cut the probe path of other keys
# that collided past this slot. hat_find stops at the first empty slot, so
# those keys become unreachable. Needs tombstones or re-insertion of the
# affected entries.
void *hat_del1(hat *t, void *key):
	bit m
	hat_hash(h, p, t, key)
	hat_find(m, t, p, h, key)
	void *oldval = NULL
	if m:
		oldval = p->val
		p->hash = 0
		p->key = NULL
		p->val = NULL
		--t->count
	return oldval


# Entry point: run the three map benchmarks in turn (hat, then flat, then
# hashtable). Each benchmark prints its own checksum and timing.
Main:
	hat_test()
	flat_test()
	hash_test()

# Benchmark the `flat` map. Runs 1,000,000 rounds of put / get / delete on
# two fixed keys, prints a checksum of the fetched value pointers (so the
# gets are not optimised away), then reports via bm("flat").
flat_test:
	bm_start()
	new(tbl, flat)
	int i = 0
	repeat(1000000):
		flat_put(tbl, "hello", "world")
		flat_put(tbl, "goodnight", "sam")
		void *hello_val = flat_get(tbl, "hello")
		void *night_val = flat_get(tbl, "goodnight")
		flat_del1(tbl, "hello")
		flat_del1(tbl, "goodnight")
		i += p2i(hello_val) + p2i(night_val)
	pr(int, i)
	bm("flat")

# Benchmark the generic `hashtable` map (put/get/del). Runs 1,000,000 rounds
# on two fixed keys, prints a checksum of the fetched value pointers, then
# reports via bm("hashtable").
hash_test:
	bm_start()
	new(tbl, hashtable)
	int i = 0
	repeat(1000000):
		put(tbl, "hello", "world")
		put(tbl, "goodnight", "sam")
		void *hello_val = get(tbl, "hello")
		void *night_val = get(tbl, "goodnight")
		del(tbl, "hello")
		del(tbl, "goodnight")
		i += p2i(hello_val) + p2i(night_val)
	pr(int, i)
	bm("hashtable")

# Benchmark the `hat` table defined above. Runs 1,000,000 rounds of
# hat_put / hat_get / hat_del1 on two fixed keys, prints a checksum of the
# fetched value pointers, then reports via bm("hat").
hat_test:
	bm_start()
	new(tbl, hat)
	int i = 0
	repeat(1000000):
		hat_put(tbl, "hello", "world")
		hat_put(tbl, "goodnight", "sam")
		void *hello_val = hat_get(tbl, "hello")
		void *night_val = hat_get(tbl, "goodnight")
		hat_del1(tbl, "hello")
		hat_del1(tbl, "goodnight")
		i += p2i(hello_val) + p2i(night_val)
	pr(int, i)
	bm("hat")


