diff options
| author | Yawning Angel <yawning@schwanenlied.me> | 2024-02-16 18:58:41 +0900 |
|---|---|---|
| committer | Yawning Angel <yawning@schwanenlied.me> | 2024-02-24 14:05:15 +0900 |
| commit | 874d6ccb6078078e7554bc40acebbb86e6e8ee7c (patch) | |
| tree | 49b6bc192c372ea799ab9a973ace3a7fae93004d | |
| parent | db3279e7da4d422005fa9224a9fb8c2301d88837 (diff) | |
core/container/avl: Initial import
| -rw-r--r-- | core/container/avl/avl.odin | 678 | ||||
| -rw-r--r-- | examples/all/all_main.odin | 2 | ||||
| -rw-r--r-- | tests/core/container/test_core_avl.odin | 161 | ||||
| -rw-r--r-- | tests/core/container/test_core_container.odin | 1 |
4 files changed, 842 insertions, 0 deletions
diff --git a/core/container/avl/avl.odin b/core/container/avl/avl.odin new file mode 100644 index 000000000..eecc1b756 --- /dev/null +++ b/core/container/avl/avl.odin @@ -0,0 +1,678 @@ +/* +package avl implements an AVL tree. + +The implementation is non-intrusive, and non-recursive. +*/ +package container_avl + +import "base:intrinsics" +import "base:runtime" +import "core:slice" + +_ :: intrinsics +_ :: runtime + +// Originally based on the CC0 implementation by Eric Biggers +// See: https://github.com/ebiggers/avl_tree/ + +// Direction specifies the traversal direction for a tree iterator. +Direction :: enum i8 { + // Backward is the in-order backwards direction. + Backward = -1, + // Forward is the in-order forwards direction. + Forward = 1, +} + +// Ordering specifies order when inserting/finding values into the tree. +Ordering :: slice.Ordering + +// Tree is an AVL tree. +Tree :: struct($Value: typeid) { + // user_data is a parameter that will be passed to the on_remove + // callback. + user_data: rawptr, + // on_remove is an optional callback that can be called immediately + // after a node is removed from the tree. + on_remove: proc(value: Value, user_data: rawptr), + + _root: ^Node(Value), + _node_allocator: runtime.Allocator, + _cmp_fn: proc(a, b: Value) -> Ordering, + _size: int, +} + +// Node is an AVL tree node. +// +// WARNING: It is unsafe to mutate value if the node is part of a tree +// if doing so will alter the Node's sort position relative to other +// elements in the tree. +Node :: struct($Value: typeid) { + value: Value, + + _parent: ^Node(Value), + _left: ^Node(Value), + _right: ^Node(Value), + _balance: i8, +} + +// Iterator is a tree iterator. +// +// WARNING: It is unsafe to modify the tree while iterating, except via +// the iterator_remove method. +Iterator :: struct($Value: typeid) { + _tree: ^Tree(Value), + _cur: ^Node(Value), + _next: ^Node(Value), + _direction: Direction, + _called_next: bool, +} + +// init initializes a tree. +init :: proc { + init_ordered, + init_cmp, +} + +// init_cmp initializes a tree. +init_cmp :: proc( + t: ^$T/Tree($Value), + cmp_fn: proc(a, b: Value) -> Ordering, + node_allocator := context.allocator, +) { + t._root = nil + t._node_allocator = node_allocator + t._cmp_fn = cmp_fn + t._size = 0 +} + +// init_ordered initializes a tree containing ordered items, with +// a comparison function that results in an ascending order sort. +init_ordered :: proc( + t: ^$T/Tree($Value), + node_allocator := context.allocator, +) where intrinsics.type_is_ordered_numeric(Value) { + init_cmp(t, slice.cmp_proc(Value), node_allocator) +} + +// destroy de-initializes a tree. +destroy :: proc(t: ^$T/Tree($Value), call_on_remove: bool = true) { + iter := iterator(t, Direction.Forward) + for _ in iterator_next(&iter) { + iterator_remove(&iter, call_on_remove) + } +} + +// len returns the number of elements in the tree. +len :: proc "contextless" (t: ^$T/Tree($Value)) -> int { + return t._size +} + +// first returns the first node in the tree (in-order) or nil iff +// the tree is empty. +first :: proc "contextless" (t: ^$T/Tree($Value)) -> ^Node(Value) { + return tree_first_or_last_in_order(t, Direction.Backward) +} + +// last returns the last element in the tree (in-order) or nil iff +// the tree is empty. +last :: proc "contextless" (t: ^$T/Tree($Value)) -> ^Node(Value) { + return tree_first_or_last_in_order(t, Direction.Forward) +} + +// find finds the value in the tree, and returns the corresponding +// node or nil iff the value is not present. +find :: proc(t: ^$T/Tree($Value), value: Value) -> ^Node(Value) { + cur := t._root + descend_loop: for cur != nil { + switch t._cmp_fn(value, cur.value) { + case .Less: + cur = cur._left + case .Greater: + cur = cur._right + case .Equal: + break descend_loop + } + } + + return cur +} + +// find_or_insert attempts to insert the value into the tree, and returns +// the node, a boolean indicating if the value was inserted, and the +// node allocator error if relevant. If the value is already +// present, the existing node is returned un-altered. +find_or_insert :: proc( + t: ^$T/Tree($Value), + value: Value, +) -> ( + n: ^Node(Value), + inserted: bool, + err: runtime.Allocator_Error, +) { + n_ptr := &t._root + for n_ptr^ != nil { + n = n_ptr^ + switch t._cmp_fn(value, n.value) { + case .Less: + n_ptr = &n._left + case .Greater: + n_ptr = &n._right + case .Equal: + return + } + } + + parent := n + n = new(Node(Value), t._node_allocator) or_return + n.value = value + n._parent = parent + n_ptr^ = n + tree_rebalance_after_insert(t, n) + + t._size += 1 + inserted = true + + return +} + +// remove removes a node or value from the tree, and returns true iff the +// removal was successful. While the node's value will be left intact, +// the node itself will be freed via the tree's node allocator. +remove :: proc { + remove_value, + remove_node, +} + +// remove_value removes a value from the tree, and returns true iff the +// removal was successful. While the node's value will be left intact, +// the node itself will be freed via the tree's node allocator. +remove_value :: proc(t: ^$T/Tree($Value), value: Value, call_on_remove: bool = true) -> bool { + n := find(t, value) + if n == nil { + return false + } + return remove_node(t, n, call_on_remove) +} + +// remove_node removes a node from the tree, and returns true iff the +// removal was successful. While the node's value will be left intact, +// the node itself will be freed via the tree's node allocator. +remove_node :: proc(t: ^$T/Tree($Value), node: ^Node(Value), call_on_remove: bool = true) -> bool { + if node._parent == node || (node._parent == nil && t._root != node) { + return false + } + defer { + if call_on_remove && t.on_remove != nil { + t.on_remove(node.value, t.user_data) + } + free(node, t._node_allocator) + } + + parent: ^Node(Value) + left_deleted: bool + + t._size -= 1 + if node._left != nil && node._right != nil { + parent, left_deleted = tree_swap_with_successor(t, node) + } else { + child := node._left + if child == nil { + child = node._right + } + parent = node._parent + if parent != nil { + if node == parent._left { + parent._left = child + left_deleted = true + } else { + parent._right = child + left_deleted = false + } + if child != nil { + child._parent = parent + } + } else { + if child != nil { + child._parent = parent + } + t._root = child + node_reset(node) + return true + } + } + + for { + if left_deleted { + parent = tree_handle_subtree_shrink(t, parent, +1, &left_deleted) + } else { + parent = tree_handle_subtree_shrink(t, parent, -1, &left_deleted) + } + if parent == nil { + break + } + } + node_reset(node) + + return true +} + +// iterator returns a tree iterator in the specified direction. +iterator :: proc "contextless" (t: ^$T/Tree($Value), direction: Direction) -> Iterator(Value) { + it: Iterator(Value) + it._tree = transmute(^Tree(Value))t + it._direction = direction + + iterator_first(&it) + + return it +} + +// iterator_from_pos returns a tree iterator in the specified direction, +// spanning the range [pos, last] (inclusive). +iterator_from_pos :: proc "contextless" ( + t: ^$T/Tree($Value), + pos: ^Node(Value), + direction: Direction, +) -> Iterator(Value) { + it: Iterator(Value) + it._tree = transmute(^Tree(Value))t + it._direction = direction + it._next = nil + it._called_next = false + + if it._cur = pos; pos != nil { + it._next = node_next_or_prev_in_order(it._cur, it._direction) + } + + return it +} + +// iterator_get returns the node currently pointed to by the iterator, +// or nil iff the node has been removed, the tree is empty, or the end +// of the tree has been reached. +iterator_get :: proc "contextless" (it: ^$I/Iterator($Value)) -> ^Node(Value) { + return it._cur +} + +// iterator_remove removes the node currently pointed to by the iterator, +// and returns true iff the removal was successful. Semantics are the +// same as the Tree remove. +iterator_remove :: proc(it: ^$I/Iterator($Value), call_on_remove: bool = true) -> bool { + if it._cur == nil { + return false + } + + ok := remove_node(it._tree, it._cur, call_on_remove) + if ok { + it._cur = nil + } + + return ok +} + +// iterator_next advances the iterator and returns the (node, true) or +// or (nil, false) iff the end of the tree has been reached. +// +// Note: The first call to iterator_next will return the first node instead +// of advancing the iterator. +iterator_next :: proc "contextless" (it: ^$I/Iterator($Value)) -> (^Node(Value), bool) { + // This check is needed so that the first element gets returned from + // a brand-new iterator, and so that the somewhat contrived case where + // iterator_remove is called before the first call to iterator_next + // returns the correct value. + if !it._called_next { + it._called_next = true + + // There can be the contrived case where iterator_remove is + // called before ever calling iterator_next, which needs to be + // handled as an actual call to next. + // + // If this happens it._cur will be nil, so only return the + // first value, if it._cur is valid. + if it._cur != nil { + return it._cur, true + } + } + + if it._next == nil { + return nil, false + } + + it._cur = it._next + it._next = node_next_or_prev_in_order(it._cur, it._direction) + + return it._cur, true +} + +@(private) +tree_first_or_last_in_order :: proc "contextless" ( + t: ^$T/Tree($Value), + direction: Direction, +) -> ^Node(Value) { + first, sign := t._root, i8(direction) + if first != nil { + for { + tmp := node_get_child(first, +sign) + if tmp == nil { + break + } + first = tmp + } + } + + return first +} + +@(private) +tree_replace_child :: proc "contextless" ( + t: ^$T/Tree($Value), + parent, old_child, new_child: ^Node(Value), +) { + if parent != nil { + if old_child == parent._left { + parent._left = new_child + } else { + parent._right = new_child + } + } else { + t._root = new_child + } +} + +@(private) +tree_rotate :: proc "contextless" (t: ^$T/Tree($Value), a: ^Node(Value), sign: i8) { + b := node_get_child(a, -sign) + e := node_get_child(b, +sign) + p := a._parent + + node_set_child(a, -sign, e) + a._parent = b + + node_set_child(b, +sign, a) + b._parent = p + + if e != nil { + e._parent = a + } + + tree_replace_child(t, p, a, b) +} + +@(private) +tree_double_rotate :: proc "contextless" ( + t: ^$T/Tree($Value), + b, a: ^Node(Value), + sign: i8, +) -> ^Node(Value) { + e := node_get_child(b, +sign) + f := node_get_child(e, -sign) + g := node_get_child(e, +sign) + p := a._parent + e_bal := e._balance + + node_set_child(a, -sign, g) + a_bal := -e_bal + if sign * e_bal >= 0 { + a_bal = 0 + } + node_set_parent_balance(a, e, a_bal) + + node_set_child(b, +sign, f) + b_bal := -e_bal + if sign * e_bal <= 0 { + b_bal = 0 + } + node_set_parent_balance(b, e, b_bal) + + node_set_child(e, +sign, a) + node_set_child(e, -sign, b) + node_set_parent_balance(e, p, 0) + + if g != nil { + g._parent = a + } + + if f != nil { + f._parent = b + } + + tree_replace_child(t, p, a, e) + + return e +} + +@(private) +tree_handle_subtree_growth :: proc "contextless" ( + t: ^$T/Tree($Value), + node, parent: ^Node(Value), + sign: i8, +) -> bool { + old_balance_factor := parent._balance + if old_balance_factor == 0 { + node_adjust_balance_factor(parent, sign) + return false + } + + new_balance_factor := old_balance_factor + sign + if new_balance_factor == 0 { + node_adjust_balance_factor(parent, sign) + return true + } + + if sign * node._balance > 0 { + tree_rotate(t, parent, -sign) + node_adjust_balance_factor(parent, -sign) + node_adjust_balance_factor(node, -sign) + } else { + tree_double_rotate(t, node, parent, -sign) + } + + return true +} + +@(private) +tree_rebalance_after_insert :: proc "contextless" (t: ^$T/Tree($Value), inserted: ^Node(Value)) { + node, parent := inserted, inserted._parent + switch { + case parent == nil: + return + case node == parent._left: + node_adjust_balance_factor(parent, -1) + case: + node_adjust_balance_factor(parent, +1) + } + + if parent._balance == 0 { + return + } + + for done := false; !done; { + node = parent + if parent = node._parent; parent == nil { + return + } + + if node == parent._left { + done = tree_handle_subtree_growth(t, node, parent, -1) + } else { + done = tree_handle_subtree_growth(t, node, parent, +1) + } + } +} + +@(private) +tree_swap_with_successor :: proc "contextless" ( + t: ^$T/Tree($Value), + x: ^Node(Value), +) -> ( + ^Node(Value), + bool, +) { + ret: ^Node(Value) + left_deleted: bool + + y := x._right + if y._left == nil { + ret = y + } else { + q: ^Node(Value) + + for { + q = y + if y = y._left; y._left == nil { + break + } + } + + if q._left = y._right; q._left != nil { + q._left._parent = q + } + y._right = x._right + x._right._parent = y + ret = q + left_deleted = true + } + + y._left = x._left + x._left._parent = y + + y._parent = x._parent + y._balance = x._balance + + tree_replace_child(t, x._parent, x, y) + + return ret, left_deleted +} + +@(private) +tree_handle_subtree_shrink :: proc "contextless" ( + t: ^$T/Tree($Value), + parent: ^Node(Value), + sign: i8, + left_deleted: ^bool, +) -> ^Node(Value) { + old_balance_factor := parent._balance + if old_balance_factor == 0 { + node_adjust_balance_factor(parent, sign) + return nil + } + + node: ^Node(Value) + new_balance_factor := old_balance_factor + sign + if new_balance_factor == 0 { + node_adjust_balance_factor(parent, sign) + node = parent + } else { + node = node_get_child(parent, sign) + if sign * node._balance >= 0 { + tree_rotate(t, parent, -sign) + if node._balance == 0 { + node_adjust_balance_factor(node, -sign) + return nil + } + node_adjust_balance_factor(parent, -sign) + node_adjust_balance_factor(node, -sign) + } else { + node = tree_double_rotate(t, node, parent, -sign) + } + } + + parent := parent + if parent = node._parent; parent != nil { + left_deleted^ = node == parent._left + } + return parent +} + +@(private) +node_reset :: proc "contextless" (n: ^Node($Value)) { + // Mostly pointless as n will be deleted after this is called, but + // attempt to be able to catch cases of n not being in the tree. + n._parent = n + n._left = nil + n._right = nil + n._balance = 0 +} + +@(private) +node_set_parent_balance :: #force_inline proc "contextless" ( + n, parent: ^Node($Value), + balance: i8, +) { + n._parent = parent + n._balance = balance +} + +@(private) +node_get_child :: #force_inline proc "contextless" (n: ^Node($Value), sign: i8) -> ^Node(Value) { + if sign < 0 { + return n._left + } + return n._right +} + +@(private) +node_next_or_prev_in_order :: proc "contextless" ( + n: ^Node($Value), + direction: Direction, +) -> ^Node(Value) { + next, tmp: ^Node(Value) + sign := i8(direction) + + if next = node_get_child(n, +sign); next != nil { + for { + tmp = node_get_child(next, -sign) + if tmp == nil { + break + } + next = tmp + } + } else { + tmp, next = n, n._parent + for next != nil && tmp == node_get_child(next, +sign) { + tmp, next = next, next._parent + } + } + return next +} + +@(private) +node_set_child :: #force_inline proc "contextless" ( + n: ^Node($Value), + sign: i8, + child: ^Node(Value), +) { + if sign < 0 { + n._left = child + } else { + n._right = child + } +} + +@(private) +node_adjust_balance_factor :: #force_inline proc "contextless" (n: ^Node($Value), amount: i8) { + n._balance += amount +} + +@(private) +iterator_first :: proc "contextless" (it: ^Iterator($Value)) { + // This is private because behavior when the user manually calls + // iterator_first followed by iterator_next is unintuitive, since + // the first call to iterator_next MUST return the first node + // instead of advancing so that `for node in iterator_next(&next)` + // works as expected. + + switch it._direction { + case .Forward: + it._cur = tree_first_or_last_in_order(it._tree, .Backward) + case .Backward: + it._cur = tree_first_or_last_in_order(it._tree, .Forward) + } + + it._next = nil + it._called_next = false + + if it._cur != nil { + it._next = node_next_or_prev_in_order(it._cur, it._direction) + } +} diff --git a/examples/all/all_main.odin b/examples/all/all_main.odin index 8f2eebc8f..fff344b22 100644 --- a/examples/all/all_main.odin +++ b/examples/all/all_main.odin @@ -14,6 +14,7 @@ import shoco "core:compress/shoco" import gzip "core:compress/gzip" import zlib "core:compress/zlib" +import avl "core:container/avl" import bit_array "core:container/bit_array" import priority_queue "core:container/priority_queue" import queue "core:container/queue" @@ -131,6 +132,7 @@ _ :: compress _ :: shoco _ :: gzip _ :: zlib +_ :: avl _ :: bit_array _ :: priority_queue _ :: queue diff --git a/tests/core/container/test_core_avl.odin b/tests/core/container/test_core_avl.odin new file mode 100644 index 000000000..f6343c5ea --- /dev/null +++ b/tests/core/container/test_core_avl.odin @@ -0,0 +1,161 @@ +package test_core_container + +import "core:container/avl" +import "core:math/rand" +import "core:slice" +import "core:testing" + +import tc "tests:common" + +@(test) +test_avl :: proc(t: ^testing.T) { + tc.log(t, "Testing avl") + + // Initialization. + tree: avl.Tree(int) + avl.init(&tree, slice.cmp_proc(int)) + tc.expect(t, avl.len(&tree) == 0, "empty: len should be 0") + tc.expect(t, avl.first(&tree) == nil, "empty: first should be nil") + tc.expect(t, avl.last(&tree) == nil, "empty: last should be nil") + + iter := avl.iterator(&tree, avl.Direction.Forward) + tc.expect(t, avl.iterator_get(&iter) == nil, "empty/iterator: first node should be nil") + + // Test insertion. + NR_INSERTS :: 32 + 1 // Ensure at least 1 collision. + inserted_map := make(map[int]^avl.Node(int)) + for i := 0; i < NR_INSERTS; i += 1 { + v := int(rand.uint32() & 0x1f) + existing_node, in_map := inserted_map[v] + + n, ok, _ := avl.find_or_insert(&tree, v) + tc.expect(t, in_map != ok, "insert: ok should match inverse of map lookup") + if ok { + inserted_map[v] = n + } else { + tc.expect(t, existing_node == n, "insert: expecting existing node") + } + } + nrEntries := len(inserted_map) + tc.expect(t, avl.len(&tree) == nrEntries, "insert: len after") + tree_validate(t, &tree) + + // Ensure that all entries can be found. + for k, v in inserted_map { + tc.expect(t, v == avl.find(&tree, k), "Find(): Node") + tc.expect(t, k == v.value, "Find(): Node value") + } + + // Test the forward/backward iterators. + inserted_values: [dynamic]int + for k in inserted_map { + append(&inserted_values, k) + } + slice.sort(inserted_values[:]) + + iter = avl.iterator(&tree, avl.Direction.Forward) + visited: int + for node in avl.iterator_next(&iter) { + v, idx := node.value, visited + tc.expect(t, inserted_values[idx] == v, "iterator/forward: value") + tc.expect(t, node == avl.iterator_get(&iter), "iterator/forward: get") + visited += 1 + } + tc.expect(t, visited == nrEntries, "iterator/forward: visited") + + slice.reverse(inserted_values[:]) + iter = avl.iterator(&tree, avl.Direction.Backward) + visited = 0 + for node in avl.iterator_next(&iter) { + v, idx := node.value, visited + tc.expect(t, inserted_values[idx] == v, "iterator/backward: value") + visited += 1 + } + tc.expect(t, visited == nrEntries, "iterator/backward: visited") + + // Test removal. + rand.shuffle(inserted_values[:]) + for v, i in inserted_values { + node := avl.find(&tree, v) + tc.expect(t, node != nil, "remove: find (pre)") + + ok := avl.remove(&tree, v) + tc.expect(t, ok, "remove: succeeds") + tc.expect(t, nrEntries - (i + 1) == avl.len(&tree), "remove: len (post)") + tree_validate(t, &tree) + + tc.expect(t, nil == avl.find(&tree, v), "remove: find (post") + } + tc.expect(t, avl.len(&tree) == 0, "remove: len should be 0") + tc.expect(t, avl.first(&tree) == nil, "remove: first should be nil") + tc.expect(t, avl.last(&tree) == nil, "remove: last should be nil") + + // Refill the tree. + for v in inserted_values { + avl.find_or_insert(&tree, v) + } + + // Test that removing the node doesn't break the iterator. + iter = avl.iterator(&tree, avl.Direction.Forward) + if node := avl.iterator_get(&iter); node != nil { + v := node.value + + ok := avl.iterator_remove(&iter) + tc.expect(t, ok, "iterator/remove: success") + + ok = avl.iterator_remove(&iter) + tc.expect(t, !ok, "iterator/remove: redundant removes should fail") + + tc.expect(t, avl.find(&tree, v) == nil, "iterator/remove: node should be gone") + tc.expect(t, avl.iterator_get(&iter) == nil, "iterator/remove: get should return nil") + + // Ensure that iterator_next still works. + node, ok = avl.iterator_next(&iter) + tc.expect(t, ok == (avl.len(&tree) > 0), "iterator/remove: next should return false") + tc.expect(t, node == avl.first(&tree), "iterator/remove: next should return first") + + tree_validate(t, &tree) + } + tc.expect(t, avl.len(&tree) == nrEntries - 1, "iterator/remove: len should drop by 1") + + avl.destroy(&tree) + tc.expect(t, avl.len(&tree) == 0, "destroy: len should be 0") +} + +@(private) +tree_validate :: proc(t: ^testing.T, tree: ^avl.Tree($Value)) { + tree_check_invariants(t, tree, tree._root, nil) +} + +@(private) +tree_check_invariants :: proc( + t: ^testing.T, + tree: ^avl.Tree($Value), + node, parent: ^avl.Node(Value), +) -> int { + if node == nil { + return 0 + } + + // Validate the parent pointer. + tc.expect(t, parent == node._parent, "invalid parent pointer") + + // Validate that the balance factor is -1, 0, 1. + tc.expect( + t, + node._balance == -1 || node._balance == 0 || node._balance == 1, + "invalid balance factor", + ) + + // Recursively derive the height of the left and right sub-trees. + l_height := tree_check_invariants(t, tree, node._left, node) + r_height := tree_check_invariants(t, tree, node._right, node) + + // Validate the AVL invariant and the balance factor. + tc.expect(t, int(node._balance) == r_height - l_height, "AVL balance factor invariant violated") + if l_height > r_height { + return l_height + 1 + } + + return r_height + 1 +} diff --git a/tests/core/container/test_core_container.odin b/tests/core/container/test_core_container.odin index 9065bed2c..f816a6bcb 100644 --- a/tests/core/container/test_core_container.odin +++ b/tests/core/container/test_core_container.odin @@ -19,6 +19,7 @@ expect_equal :: proc(t: ^testing.T, the_slice, expected: []int, loc := #caller_l main :: proc() { t := testing.T{} + test_avl(&t) test_small_array(&t) tc.report(&t) |