#!/usr/local/bin/cz --
use b

# Source directory: one file per zone, named after the zone (see load_zone).
cstr freens_dir = "/freens"
# Directory the generated BIND zone files are written to.
cstr bind_dir = "/etc/bind/freens"
# Generated named.conf fragment: one "zone" stanza per output zone.
cstr named_conf = "/etc/bind/named.conf.freens"

# Command run by signal_named() when a zone file was written or removed.
cstr bind_reload = "/usr/sbin/rndc reload"

# SOA header written at the top of every generated zone file; the single %d
# receives the serial.  TTL and timer values are in seconds (300 s is
# 5 minutes, which the $TTL comment previously mislabelled as 1 minute).
cstr header =
 "$TTL	300	; 5 minutes\n"
 "@		IN	SOA	ns1.nipl.net. root.nipl.net. (\n"
 "				%d	; Serial\n"
 "				3600		; Refresh	 1 hour\n"
 "				1200		; Retry		20 minutes\n"
 "				2419200		; Expire	 4 weeks\n"
 "				300 )		; -ve Cache TTL	 5 min\n"

# Target given to a zone's "." (apex) name when its source file defines
# none; see process_zone_add_main().
cstr default_main = "nipl.net"

# Fallback NS record text that output_zone() may append to a zone file.
cstr default_ns =
 "@	NS	ns1.nipl.net.\n"

# Fallback MX record text that output_zone() may append to a zone file.
cstr default_mx =
 "@	MX	10	mx1.nipl.net.\n"

# printf format for one named.conf stanza; arguments are the zone name,
# bind_dir and the zone file name (see output_zone).
cstr named_conf_template =
 "zone \"%s\" {\n"
 "	type master;\n"
 "	file \"%s/%s\";\n"
 "};\n"

#int ttl = 600
# Preference value written on every generated MX record.
int mx_priority = 10
# When set, process_zone() resolves every CNAME, not only those on the apex,
# on nsN / mxN names, and on names with more than one CNAME/A/AAAA record.
boolean resolve_all = 0

# TODO code to multi-ping hosts to check which hosts are up or down for
#   round-robin DNS and disable the ones that are down.

# verbose() tracing is disabled; swap in the commented-out definition below
# it to enable.
#def verbose warn
def verbose void

# Names of all zones: those listed in freens_dir, plus the reverse zones that
# process_zone() appends.
vec *zone_names

# Record type of a node.  target_type() picks CNAME, A or AAAA from the
# target text; NS, MX and PTR nodes are created by process_zone().
typedef enum { N_CNAME, N_A, N_AAAA, N_NS, N_MX, N_PTR } node_type

# One record of a name.  Each name's records form a singly linked list;
# `to` is the target text (an address, a host name, or "." for the apex).
struct node
	node *next
	cstr to
	node_type type

# A zone: its names in first-seen order, and a table from each name to the
# head of its node list.
struct zone
	vec node_names
	hashtable nodes

# Set up an empty zone: the name -> node-list table (nodes_n_buckets
# buckets) and the ordered name vector (initial capacity 16).
zone_init(zone *z)
	init(&z->nodes, hashtable, cstr_hash, (eq_fn *)cstr_eq, nodes_n_buckets)
	init(&z->node_names, vec, cstr, 16)

# Fill in a record node.  `to` is stored as-is (not copied); the type is
# whatever target_type() infers from it; `next` links it into a list.
node_init(node *n, node *next, cstr to)
	n->type = target_type(to)
	n->to = to
	n->next = next

# Guess a record's type from its target text:
#   only digits and dots, containing a dot, and not just "."  -> N_A
#   only digits, lowercase a-f and colons, containing a colon -> N_AAAA
#   anything else (host names, ".")                           -> N_CNAME
node_type target_type(cstr to)
	size_t len = strlen(to)
	boolean ipv4_chars_only = strspn(to, "0123456789.") == len
	if ipv4_chars_only && strchr(to, '.') && strcmp(to, ".")
		return N_A
	boolean ipv6_chars_only = strspn(to, "0123456789abcdef:") == len
	if ipv6_chars_only && strchr(to, ':')
		return N_AAAA
	return N_CNAME

# Record-type keyword used in the zone file for t, or NULL for a value
# outside the enum.  The table is indexed by node_type, so its order must
# match the enum declaration.
cstr node_type_cstr(node_type t)
	static cstr names[] = { "CNAME", "A", "AAAA", "NS", "MX", "PTR" }
	if t >= N_CNAME && t <= N_PTR
		return names[t]
	return NULL

# Zone name -> zone *, for loaded zones and generated reverse zones.
hashtable _zones, *zones = &_zones

# Bucket counts for the zones table and for each zone's node table.
int zones_n_buckets = 101
int nodes_n_buckets = 101

# SOA serial used for every output zone: the newest mtime among the source
# zone files read by load_zone() (-1 before any has been read).
time_t serial = -1

# Output stream for named_conf, open while Main() runs output_zone().
FILE *named_conf_s

# Set when a zone file is written or removed; Main() then calls signal_named().
boolean changed = 0

# Entry point: load every zone file in freens_dir, give each zone an apex,
# derive reverse zones and flatten CNAMEs, write the BIND zone files and
# named_conf, delete generated files for zones that no longer exist, and
# reload named if anything changed.
Main()
	zone_names = Ls(freens_dir)
	init(zones, hashtable, cstr_hash, (eq_fn *)cstr_eq, zones_n_buckets)
	for_vec(i, zone_names, cstr)
		load_zone(*i)
	for_vec(i, zone_names, cstr)
		process_zone_add_main(*i)
	# process_zone() pushes new reverse zones onto zone_names, which may
	# reallocate the vector, so do not hold a pointer into it across the
	# call: index it instead, over the zones loaded from freens_dir.
	int n_loaded = vec_get_size(zone_names)
	for(i, 0, n_loaded)
		process_zone(*(cstr *)v(zone_names, i))
	if !exists(bind_dir)
		Mkdirs(bind_dir)
	named_conf_s = Fopenout(named_conf)
	for_vec(i, zone_names, cstr)
		output_zone(*i)
	Fclose(named_conf_s)
	remove_old_zones()
	if changed
		signal_named()

# Load zone file freens_dir/<name> into the zones table.
# Each line that is not empty and does not start with '#' or ';' is
# "from<TAB>to" and adds one record (node) for `from`, typed from `to`.
# Names keep first-seen order in node_names; each name's records keep file
# order.  serial is raised to the file's mtime if that is newer.
# Calls error() when a field contains a space or tab (Strchr() presumably
# does likewise for a line with no tab at all).
load_zone(cstr name)
	key_value *kv = kv(zones, name, NULL)
	if !kv->v
		NEW(kv->v, zone)
	zone *z = kv->v

	cstr file = path_cat(freens_dir, name)
	F_in(file)
		stats st
		Fstat(fileno((FILE*)in->data), &st)
		# compare the time_t values directly rather than through imax(),
		# which (presumably int-typed) would truncate large mtimes
		if st.st_mtime > serial
			serial = st.st_mtime
		Eachline(l)
			if among(*l, '#', ';', '\0')
				continue
			cstr to = Strchr(l, '\t')
			*to++ = '\0'
			cstr from = l
			if strpbrk(from, " \t") || strpbrk(to, " \t")
				error("invalid record contains whitespace: %s:[%s] [%s]", name, from, to)
			# new records go on the front of the name's list; the lists are
			# reversed below to restore file order
			key_value *kv = kv(&z->nodes, from, NULL)
			New(n, node, kv->v, Strdup(to))
			if !kv->v
				# first record for this name: give the table key and the name
				# list their own copy instead of a pointer into the line buffer
				from = Strdup(from)
				kv->k = from
				vec_push(&z->node_names, from)
			kv->v = n

	for_vec(i, &z->node_names, cstr)
		cstr from = *i
		key_value *kv = KV(&z->nodes, from)
		kv->v = list_reverse((list *)kv->v)

	Free(file)

# Number of CNAME, A and AAAA records in list l; NS, MX and PTR nodes are
# not counted.
int list_count_in_records(node *l)
	int total = 0
	node *p = l
	while p != NULL
		if among(p->type, N_CNAME, N_A, N_AAAA)
			++total
		p = p->next
	return total

# Ensure zone `name` has a "." (apex) name: if its source file defined none,
# add one whose single record targets default_main.
process_zone_add_main(cstr name)
	zone *z = Get(zones, name)

	for_vec(i, &z->node_names, cstr)
		if cstr_eq(*i, ".")
			return

	New(n, node, NULL, Strdup(default_main))
	vec_push(&z->node_names, Strdup("."))
	put(&z->nodes, Strdup("."), n)

# Append node n to the end of zone z's "." (apex) record list.  The list
# head is looked up on every call because resolving the apex's own CNAMEs
# frees and replaces nodes in that list, so a cached tail pointer can dangle.
append_main_node(zone *z, node *n)
	node **link = (node **)&(KV(&z->nodes, ".")->v)
	while *link != NULL
		link = &(*link)->next
	*link = n

# Post-process loaded zone `name` in two passes over its names:
#  1. for each A/AAAA record, create the reverse zone named by
#     reverse_zone(), holding one PTR back to this name, unless that reverse
#     zone already exists (new reverse zones are appended to zone_names);
#  2. give each nsN / mxN name one NS / MX record at the apex, and replace
#     CNAME records on the apex, on nsN / mxN names, on names with more than
#     one record, or everywhere if resolve_all is set, by what resolve()
#     returns for them.
process_zone(cstr name)
	zone *z = Get(zones, name)
	int count = vec_get_size(&z->node_names)

	# add reverse dns zones

	for(i, 0, count)
		cstr from = *(cstr *)v(&z->node_names, i)
		node *l = Get(&z->nodes, from)
		for_list(n, &l)
			cstr to = n->to
			node_type type = n->type
			if among(type, N_A, N_AAAA)
				cstr rev = reverse_zone(to, type)
				key_value *kv = kv(zones, rev, NULL)
				if !kv->v
					vec_push(zone_names, rev)
					NEW(kv->v, zone)
					zone *rz = kv->v
					# the apex ("." or the zone's own name, as in output_zone)
					# maps back to the zone name itself, not "..name"
					cstr ptr_to
					if cstr_eq(from, ".") || cstr_eq(from, name)
						ptr_to = Strdup(name)
					 else
						ptr_to = Format("%s.%s", from, name)
					New(ptr, node, NULL, ptr_to)
					ptr->type = N_PTR
					vec_push(&rz->node_names, Strdup("."))
					put(&rz->nodes, Strdup("."), ptr)
				 else
					Free(rev)

	# resolve cnames, add ns and mx records

	for(i, 0, count)
		cstr from = *(cstr *)v(&z->node_names, i)
		verbose("\nprocessing nodes from %s %s", name, from)
		node **l = (node **)&(KV(&z->nodes, from)->v)
		boolean is_multi = list_count_in_records(*l) > 1
		boolean is_main = 0, is_ns = 0, is_mx = 0
		if cstr_eq(from, ".")
			is_main = 1
		 eif !strncmp(from, "ns", 2) && isdigit((unsigned char)from[2])
			is_ns = 1
		 eif !strncmp(from, "mx", 2) && isdigit((unsigned char)from[2])
			is_mx = 1

		# one apex NS / MX record per nsN / mxN name, not one per address
		if is_ns
			verbose("adding ns record")
			New(ns, node, NULL, Strdup(from))
			ns->type = N_NS
			append_main_node(z, ns)
		 eif is_mx
			verbose("adding mx record")
			New(mx, node, NULL, Strdup(from))
			mx->type = N_MX
			append_main_node(z, mx)

		for_list(n, l, link0, link1)
			cstr to = n->to
			node_type type = n->type
			if type == N_CNAME && (is_main || is_ns || is_mx || is_multi || resolve_all)
				# NOTE: the `resolved` vec itself is never freed; its strings
				# become the `to` of the new nodes below
				new(resolved, vec, cstr, 8)
				resolve(z, name, to, resolved, 0)
				Free(n->to)
				# unlink and free the CNAME node; iteration resumes at *link0
				*link0 = n->next
				link1 = link0
				Free(n)
				# splice the resolved targets in where the CNAME was, leaving
				# link0 after them so they are not revisited
				for_vec(j, resolved, cstr)
					New(r, node, *link0, *j)
					*link0 = r
					link0 = &r->next
					link1 = link0

# Unlink the current node inside a for_list(n, l, link0, link1) loop by
# copying *link1 into *link0 (presumably link1 is &n->next - confirm).
# TODO remove or fix: it only works when the caller's loop variables are
# named exactly link0 and link1; a struct holding the iteration state
# would avoid that.
def for_list_rm()
#	warn("link0 %p link1 %p", link0, link1)
#	warn("*link0 %p *link1 %p", *link0, *link1)
	*link0 = *link1
#	warn("link0 %p link1 %p", link0, link1)
#	warn("*link0 %p *link1 %p", *link0, *link1)

# Longest CNAME chain resolve() follows before reporting a loop.
int resolve_max_depth = 32

# TODO cache resolve targets

# Resolve target `to`, written relative to zone `name` (z_from), adding the
# final targets to the set `resolved` (via cstr_set_add):
#   "."            -> the apex of z_from
#   "<zone>"       -> the apex of a loaded zone
#   "host.<zone>"  -> name "host" in loaded zone <zone> (split at first dot)
#   "host"         -> name "host" in z_from
# A name that has records contributes its A/AAAA targets plus, recursively,
# whatever its CNAME targets resolve to.  Names without records are added
# as they are, bare host names qualified as "host.<name>".
# `to` is modified temporarily (one '.' is overwritten, then restored).
# Calls error() if the chain is resolve_max_depth or more steps deep.
resolve(zone *z_from, cstr name, cstr to, vec *resolved, int depth)
	verbose("resolve: %s %s", name, to)
	if depth >= resolve_max_depth
		error("cannot resolve %s %s within %d steps - name loop?", name, to, resolve_max_depth)
	zone *z_to = NULL
	boolean is_main = !strcmp(to, ".")
	cstr z_to_name = strchr(to, '.')
	# name of the zone that z_to's targets are relative to, for recursive
	# calls; stays `name` for the apex and for bare host names, which live in
	# z_from (z_to_name is "." or NULL there and must not be passed down)
	cstr sub_name = name
	# heap copy of "." used as the lookup key for a zone-name target
	cstr apex_key = NULL
	int need_put_back_dot = 0
	if is_main
		z_to = z_from
	 else
		z_to = get(zones, to)
		if z_to
			z_to_name = to
			sub_name = to
			to = apex_key = Strdup(".")
		 eif z_to_name
			*z_to_name = '\0'
			++z_to_name
			need_put_back_dot = 1
			z_to = get(zones, z_to_name)
			sub_name = z_to_name
		 else
			z_to = z_from
	node *l = NULL
	if z_to
		l = get(&z_to->nodes, to)
	if need_put_back_dot
		z_to_name[-1] = '.'
	if !l
		cstr to_full
		if z_to_name
			to_full = Strdup(to)
		 else
			to_full = Format("%s.%s", to, name)
		cstr_set_add(resolved, to_full)
	 else
		for_list(n, &l)
			if n->type == N_CNAME
				verbose("recursive resolve")
				resolve(z_to, sub_name, ((node *)n)->to, resolved, depth+1)
			 eif among(n->type, N_A, N_AAAA)
				cstr_set_add(resolved, Strdup(n->to))
	if apex_key
		Free(apex_key)
	verbose("resolved to: %s (depth %d)", joinv(" ", resolved), depth)

# Write the BIND zone file for zone `name` and append its named.conf stanza.
# The file is built as bind_dir/.<name>.new and passed to the freens-post
# hook. It then replaces bind_dir/<name> only if that is missing or differs
# (setting `changed`); otherwise the temporary file is removed.
# Apex names ("." or the zone name) are written as "<name>.", a "." target
# likewise, and CNAME/PTR targets containing a dot get a trailing "." to make
# them absolute.  default_ns / default_mx are appended when the zone got
# no NS / MX records.
output_zone(cstr name)
	boolean has_ns = 0
	boolean has_mx = 0

	zone *z = Get(zones, name)
	cstr zone_from = cstr_cat(name, ".")

	cstr file = path_cat(bind_dir, name)
	cstr file_new = cstr_cat(bind_dir, path__sep_cstr, ".", name, ".new")

	Fsayf(named_conf_s, named_conf_template, name, bind_dir, name)

	F_out(file_new)
		# header prints the serial with %d, so pass an int rather than a time_t
		sf(header, (int)serial)
		for_vec(i, &z->node_names, cstr)
			cstr from = *i
			node *l = Get(&z->nodes, *i)
			cstr from_add_dot = ""
			if !strcmp(from, ".") || !strcmp(from, name)
				from = zone_from
#			 eif strchr(from, '.')
#				from_add_dot = "."
			for_list(n, &l)
				cstr to = n->to
				node_type type = n->type
				# reset per record, so an absolute CNAME/PTR target does not
				# also make the next record's target (e.g. "ns1") absolute
				cstr to_add_dot = ""
				if !strcmp(to, ".")
					to = zone_from
				 eif among(type, N_CNAME, N_PTR) && strchr(to, '.')
					to_add_dot = "."
				if among(type, N_A, N_AAAA, N_CNAME, N_PTR)
					sf("%s%s\tIN\t%s\t%s%s", from, from_add_dot, node_type_cstr(type), to, to_add_dot)
				 eif type == N_NS
					sf("%s%s\tIN\tNS\t%s%s", from, from_add_dot, to, to_add_dot)
					has_ns = 1
				 eif type == N_MX
					sf("%s%s\tMX\t%d\t%s%s", from, from_add_dot, mx_priority, to, to_add_dot)
					has_mx = 1

		if !has_ns && default_ns
			pf(default_ns)
		# default MX depends on whether the zone has MX records (was has_ns)
		if !has_mx && default_mx
			pf(default_mx)

	Systeml("freens-post", name, file_new)

	if !exists(file) || file_cmp(file, file_new)
		Rename(file_new, file)
		changed = 1
	 else
		Remove(file_new)

	Free(file)
	Free(file_new)
	Free(zone_from)

# Return the reverse-DNS zone name for address `to` as a new string (the
# caller Free()s it):
#   N_A     "1.2.3.4"     -> "4.3.2.1.in-addr.arpa"
#   N_AAAA  "2001:db8::1" -> "1.0.0.0.0.0.0.0. ... .8.b.d.0.1.0.0.2.ip6.arpa"
#           (32 nibble labels)
# Any other type is an error.  Hex digits are copied as written, so AAAA
# addresses are expected in lowercase, the only form target_type() accepts.
cstr reverse_zone(cstr to, node_type type)
	new(b, buffer, 512)
	cstr work = Strdup(to)

	if type == N_A
		# take octets off the end of work, appending each followed by "."
		repeat
			char *i = strrchr(work, '.')
			if i
				*i++ = '\0'
			 else
				i = work
			buffer_cat_cstr(b, i)
			buffer_cat_char(b, '.')
			# i - 2 < work means the octet just written was the leading one.
			# NOTE(review): when i == work this forms a pointer before the
			# string, which strict C does not permit.
			i -= 2
			if i < work
				break
		buffer_cat_cstr(b, "in-addr.arpa")

	 eif type == N_AAAA
		# this is excessively complicated!
		# pass 1: count the groups written out, and note whether "::" occurs
		int sections = 0
		char *i = work
		char *j
		boolean has_middle = 0
		repeat
			j = strchr(i, ':')
			if !j
				if strlen(i)
					++sections
				break
			if j > i
				++sections
			i = j+1
			if *i == ':'
				has_middle = 1
				++i
		# "::" stands for the all-zero groups missing from the full 8
		int middle_size = 0
		if has_middle
			middle_size = 8 - sections
		# pass 2: walk the groups from the end.  Each group's hex digits are
		# emitted in reverse, one per label, padded with "0." to 4 nibbles;
		# at the "::" emit 4 "0." nibbles per elided group.
		repeat
			boolean is_middle = 0
			char *e = work + strlen(work)-1
			i = strrchr(work, ':')
			if i
				if i > work && i[-1] == ':'
					is_middle = 1
					i[-1] = '\0'
				*i++ = '\0'
			 else
				i = work
			if i <= e
				int c = 4
				while e >= i && c--
					buffer_cat_char(b, *e--)
					buffer_cat_char(b, '.')
				while c--
					buffer_cat_cstr(b, "0.")
			if is_middle
				repeat(4*middle_size)
					buffer_cat_cstr(b, "0.")
			if i == work
				break
		buffer_cat_cstr(b, "ip6.arpa")

	 else
		error("reverse_zone: bad node type")

	Free(work)
	return buffer_to_cstr(b)

# Delete every file in bind_dir whose name is not a current zone name,
# setting `changed` for each one removed.
remove_old_zones()
	vec *entries = ls(bind_dir)
	for_vec(e, entries, cstr)
		cstr entry = *e
		if get(zones, entry)
			continue
		cstr stale = path_cat(bind_dir, entry)
		Remove(stale)
		Free(stale)
		changed = 1

# Run the bind_reload command; Main() calls this only when `changed` is set.
signal_named()
	System(bind_reload)
