# hat - flat hash table mark 2

# TODO try an intern hash + value loc hash combo = fast?
# intern hash: key, no value : set
# key loc hash/set: no hash func / simple, don't store hash

# for custom hash iterators, too complex now to return multiple values,
# so use internal macros or a nice loop wrapper macro for it.

# cz - void type, uses no space, for degenerate template instances
# such as set vs hash

# TODO detect if hash is full, and double / rehash

use b

# A hash table using open addressing with double hashing (see hat_step).
# Slots live in the contiguous array [p0, p1); a slot whose hash is 0 is
# empty.  Fields:
#   size   number of slots; must be prime (see hat_init_prime)
#   count  number of occupied slots
#   p0     first slot
#   p1     one past the last slot (p0 + size)
#   hash   key hash function
#   cmp    key comparison; returns 0 when keys are equal (see hat_match)
struct hat:
	int size
	int count
	hat_entry *p0
	hat_entry *p1
	hash_fn *hash
	cmp_fn *cmp

# One slot of a hat table.  hash == 0 marks an empty slot (hat_hash never
# stores 0 for a real key); key and val are opaque pointers owned by the
# caller.
struct hat_entry:
	hash_t hash
	void *key
	void *val

# Slot count used when hat_init is given no size (101 is prime, as
# hat_init_prime requires).
def hat_default_size 101
# Load limit in 256ths of the slot count: hat_full(t) trips once count+1
# reaches size * hat_full_per_256 / 256 (128 => half full).
def hat_full_per_256 128

# Sizes are assumed to be < 2^32; for anything bigger, an mmap'd file or
# libdb is a better fit.
#
# hat_init(d [, hash, cmp] [, size]): initialise table d.
#   hash/cmp default to cstr_hash/cstr_cmp.
#   size defaults to hat_default_size; an explicit size is first passed
#   through prime_2pow_32 (presumably to pick a suitable prime - confirm).
def hat_init(d)
	hat_init_prime(d, cstr_hash, cstr_cmp, hat_default_size)
def hat_init(d, size)
	hat_init_prime(d, cstr_hash, cstr_cmp, prime_2pow_32(size))
def hat_init(d, hash, cmp)
	hat_init_prime(d, hash, cmp, hat_default_size)
def hat_init(d, hash, cmp, size)
	hat_init_prime(d, hash, cmp, prime_2pow_32(size))

# hat_init_prime(d [, hash, cmp], size): like hat_init, but size is used as
# given and MUST be prime.  With a prime size, every stride hat_step can
# produce (1 .. size-1) is coprime with size, so a probe visits every slot.
# hash/cmp default to cstr_hash/cstr_cmp.
def hat_init_prime(d, size)
	hat_init_prime(d, cstr_hash, cstr_cmp, size)
hat_init_prime(hat *d, hash_fn *hash, cmp_fn *cmp, size_t size)
	d->hash = hash
	d->cmp = cmp
	d->count = 0
	d->size = size
	# NOTE(review): empty slots are recognised by hash == 0, so this assumes
	# Valloc returns zero-filled memory - confirm.
	d->p0 = Valloc(hat_entry, size)
	d->p1 = &d->p0[size]

# hat_hash(h, p, t, key): declare hash_t h, the hash of key via t->hash.
# A hash of 0 is remapped to 1, because 0 marks an empty slot.  Also declare
# hat_entry *p pointing at key's home slot (h mod size).
def hat_hash(h, p, t, key):
	hash_t h = t->hash(key)
	if !h:
		h = 1
	hat_entry *p = &t->p0[h % t->size]

# hat_step(step, t, h): set step to the probe stride for hash h (the
# second hash of double hashing).  The result is in 1 .. t->size-1, so it is
# never 0 and probing always advances.
def hat_step(step, t, h):
	step = 1 + h % (t->size-1)

# hat_next(p, t, step): advance slot pointer p by step, wrapping past the
# end of the slot array.  One subtraction suffices because step < t->size.
def hat_next(p, t, step):
	p += step
	if t->p1 <= p:
		p -= t->size

# hat_p_full(p): nonzero iff slot p is occupied (empty slots have hash 0).
# The argument and body are parenthesised so that expression arguments such
# as hat_p_full(q + 1) expand correctly.
def hat_p_full(p) ((p)->hash)

# hat_match(t, p, h, k): true iff slot p holds hash h and a key that t->cmp
# reports equal to k (cmp returns 0 for equal).  The hash is compared first,
# so cmp only runs on a hash hit and never on an empty slot (stored hash 0
# never equals a hat_hash result).  The body is fully parenthesised so that
# uses like !hat_match(...) keep their intended precedence.
def hat_match(t, p, h, k) ((p)->hash == (h) && !(t)->cmp((p)->key, (k)))

# hat_find(m, t, p, h, k [, step]): search for key k (hash h) starting at
# slot p, as set up by hat_hash.  Afterwards:
#   - m is true iff a matching entry was found, and p points at it;
#   - otherwise p points at the empty slot that ended the search, which is
#     where k may be inserted.
# Termination relies on the table always keeping an empty slot (hat_put
# grows the table via hat_full).  step names the stride variable; the
# 5-argument form supplies my(step), presumably a collision-free local name.
def hat_find(m, t, p, h, k)
	hat_find(m, t, p, h, k, my(step))
def hat_find(m, t, p, h, k, step)
	m = hat_match(t, p, h, k)
	if !m && p->hash:
		int step
		hat_step(step, t, h)
		repeat:
			hat_next(p, t, step)
			m = hat_match(t, p, h, k)
			if m || !p->hash:
				break

# hat_full(t): true once inserting one more entry would reach the load limit
# of t->size * hat_full_per_256 / 256 slots.  The product is formed in
# long long so that tables of 2^24 slots or more do not overflow int.  The
# expression is parenthesised so that !hat_full(t) works as expected.
def hat_full(t) ((t)->count + 1 >= (long long)(t)->size * hat_full_per_256 / 256)

# hat_put(t, key, val): store key -> val in t.  If an equal key is already
# present, its key and value are overwritten.  The table is grown first
# (hat_double) when hat_full(t).
# Returns the previous value for an existing key, else NULL.
# FIXME this doesn't distinguish 'replace' vs 'new' in the return code.
void *hat_put(hat *t, void *key, void *val):
	if hat_full(t):
#		warn("hat_full: %d+1 >= %d * %d / 256", t->count, t->size, hat_full_per_256)
		hat_double(t)
	bit m
	hat_hash(h, p, t, key)
	hat_find(m, t, p, h, key)
	void *prev = NULL
	if !m:
		++t->count
	else:
		prev = p->val
	p->hash = h
	p->key = key
	p->val = val
	return prev

# hat_get(t, key): return the value stored for key, or NULL if key is absent.
# FIXME this doesn't indicate 'not found' as an error condition
void *hat_get(hat *t, void *key):
	bit found
	hat_hash(h, p, t, key)
	hat_find(found, t, p, h, key)
	return found ? p->val : NULL

# hat_del1(t, key): remove key from t.
# Returns the removed value, or NULL if key was absent.
#
# hat_find stops probing at the first empty slot.  Simply zeroing the
# deleted slot would therefore cut the probe chain of every entry that was
# placed by stepping over it, making those entries unfindable.  With double
# hashing the chains are per-key, so entries cannot just be shifted back.
# Instead, the remaining entries are re-inserted into a fresh table of the
# same size.  This costs O(size) per delete.
# TODO tombstones would make this O(1).
void *hat_del1(hat *t, void *key):
	bit m
	hat_hash(h, p, t, key)
	hat_find(m, t, p, h, key)
	if !m:
		return NULL
	void *oldval = p->val
	# vacate the slot so the rebuild below skips it
	p->hash = 0
	p->key = NULL
	p->val = NULL
	--t->count
	# The rebuild never trips hat_full: it holds fewer entries than t did
	# when t last accepted an insert.
	hat fresh
	hat_init_prime(&fresh, t->hash, t->cmp, t->size)
	hat_entry *q = t->p0
	while q < t->p1:
		if q->hash:
			hat_put(&fresh, q->key, q->val)
		++q
	# NOTE(review): assumes memory from Valloc is released with free() - confirm.
	free(t->p0)
	*t = fresh
	return oldval

# TODO hat_each

# hat_is_prime(n): trial-division primality test.  hat_double uses it to
# choose a prime table size, as hat_init_prime requires.
int hat_is_prime(size_t n):
	if n < 2:
		return 0
	if n % 2 == 0:
		return n == 2
	size_t d = 3
	while d * d <= n:
		if n % d == 0:
			return 0
		d += 2
	return 1

# hat_double(t): grow t to the smallest prime greater than twice its current
# size, then re-insert every live entry.  hat_put calls this when
# hat_full(t).  Entries are re-inserted with hat_put, so t->hash is
# recomputed for each key.  The new table holds fewer entries than half its
# slots, so the re-insertion never recurses back into hat_double.
hat_double(hat *t):
	# all primes above 2 are odd, so step through odd candidates only
	size_t n = (size_t)t->size * 2 + 1
	while !hat_is_prime(n):
		n += 2
	# hat.size is an int
	if n > 2147483647:
		error("hat full!  cannot grow the table past 2^31 slots")
	hat bigger
	hat_init_prime(&bigger, t->hash, t->cmp, n)
	hat_entry *q = t->p0
	while q < t->p1:
		if q->hash:
			hat_put(&bigger, q->key, q->val)
		++q
	# NOTE(review): assumes memory from Valloc is released with free() - confirm.
	free(t->p0)
	*t = bigger

