diff --git a/include/mcl/container/hmap.hpp b/include/mcl/container/hmap.hpp index 1212f02..17e880f 100644 --- a/include/mcl/container/hmap.hpp +++ b/include/mcl/container/hmap.hpp @@ -5,6 +5,8 @@ #pragma once #include +#include +#include #include #include "mcl/assert.hpp" @@ -14,12 +16,12 @@ #include "mcl/stdint.hpp" #if defined(MCL_ARCHITECTURE_ARM64) -# include "arm_neon.h" +# include #endif namespace mcl { -template +template class hmap; namespace detail { @@ -73,20 +75,25 @@ struct meta_byte_group { return vmaxvq_u8(data) == 0xff; } + bool is_all_empty_or_tombstone() + { + return vminvq_u8(vandq_u8(data, vdupq_n_u8(0x80))) == 0x80; + } + uint8x16_t data; }; # define MCL_HMAP_MATCH_META_BYTE_GROUP(MATCH, ...) \ { \ - const uint64x2_t match_result = MATCH; \ + const uint64x2_t match_result{MATCH}; \ \ for (u64 match_result_v{match_result[0]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \ - const size_t match_index = std::countr_zero(match_result_v) / 8; \ + const size_t match_index{static_cast(std::countr_zero(match_result_v) / 8)}; \ __VA_ARGS__ \ } \ \ for (u64 match_result_v{match_result[1]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \ - const size_t match_index = 8 + std::countr_zero(match_result_v) / 8; \ + const size_t match_index{static_cast(8 + std::countr_zero(match_result_v) / 8)}; \ __VA_ARGS__ \ } \ } @@ -103,12 +110,17 @@ union slot_union { } // namespace detail -template +template class hmap_iterator { + using base_value_type = std::pair; + using slot_type = detail::slot_union; + public: + using key_type = KeyType; + using mapped_type = MappedType; using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = std::pair; + using value_type = std::conditional_t, base_value_type>; using pointer = value_type*; using const_pointer = const value_type*; using reference = value_type&; @@ -156,9 +168,7 @@ public: } private: - friend class hmap; - - using slot_type = detail::slot_union; + friend class hmap; hmap_iterator(detail::meta_byte* mb_ptr, slot_type* slot_ptr) : mb_ptr{mb_ptr}, slot_ptr{slot_ptr} @@ -184,41 +194,57 @@ private: } } - detail::meta_byte* mb_ptr = nullptr; - slot_type* slot_ptr = nullptr; + detail::meta_byte* mb_ptr{nullptr}; + slot_type* slot_ptr{nullptr}; }; -template +template, typename Pred = std::equal_to> class hmap { public: using key_type = KeyType; using mapped_type = MappedType; - using value_type = std::pair; - using iterator = hmap_iterator; + using hasher = Hash; + using key_equal = Pred; + using value_type = std::pair; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + using iterator = hmap_iterator; + using const_iterator = hmap_iterator; private: using slot_type = detail::slot_union; static_assert(!std::is_reference_v); static_assert(!std::is_reference_v); - static constexpr size_t group_size = 16; + static constexpr size_t group_size{16}; public: hmap() { initialize_members(1); } + hmap(const hmap&); + hmap(hmap&&); + hmap& operator=(const hmap&); + hmap& operator=(hmap&&); ~hmap() { - for (auto iter = begin(); iter != end(); ++iter) { - iter->~value_type(); - } + clear(); } + [[nodiscard]] bool empty() const noexcept { return full_slots == 0; } + size_type size() const noexcept { return full_slots; } + size_type max_size() const noexcept { return static_cast(std::numeric_limits::max()); } + iterator begin() { - iterator result{get_iterator_at(0)}; + iterator result{iterator_at(0)}; result.skip_empty_or_tombstone(); return result; } @@ -226,34 +252,26 @@ public: { return {}; } - - iterator find(const key_type& key) + const_iterator cbegin() const { - const size_t hash = std::hash{}(key); - const detail::meta_byte mb = detail::meta_byte_from_hash(hash); - - size_t group_index = detail::group_index_from_hash(hash, group_index_mask); - - while (true) { - detail::meta_byte_group g{mbs.get() + group_index * group_size}; - - MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { - const size_t item_index{group_index * group_size + match_index}; - - if (slots[item_index].value.first == key) [[likely]] { - return get_iterator_at(item_index); - } - }); - - if (g.is_any_empty()) [[likely]] { - return {}; - } - - group_index = (group_index + 1) & group_index_mask; - } + const_iterator result{const_iterator_at(0)}; + result.skip_empty_or_tombstone(); + return result; + } + const_iterator cend() const + { + return {}; + } + const_iterator begin() const + { + return cbegin(); + } + const_iterator end() const + { + return cend(); } - template + template std::pair try_emplace(K&& k, Args&&... args) { auto [item_index, item_found] = find_key_or_empty_slot(k); @@ -263,10 +281,10 @@ public: std::forward_as_tuple(std::forward(k)), std::forward_as_tuple(std::forward(args)...)); } - return {get_iterator_at(item_index), !item_found}; + return {iterator_at(item_index), !item_found}; } - template + template std::pair insert_or_assign(K&& k, V&& v) { auto [item_index, item_found] = find_key_or_empty_slot(k); @@ -277,27 +295,40 @@ public: std::forward(k), std::forward(v)); } - return {get_iterator_at(item_index), !item_found}; + return {iterator_at(item_index), !item_found}; } + void erase(const_iterator position) + { + if (position == cend()) { + return; + } + + const size_t item_index{static_cast(std::distance(mbs.get(), position.mb_ptr))}; + const size_t group_index{item_index / group_size}; + const detail::meta_byte_group g{mbs.get() + group_index * group_size}; + + erase_impl(item_index, std::move(g)); + } + void erase(iterator position) + { + if (position == end()) { + return; + } + + const size_t item_index{static_cast(std::distance(mbs.get(), position.mb_ptr))}; + const size_t group_index{item_index / group_size}; + const detail::meta_byte_group g{mbs.get() + group_index * group_size}; + + erase_impl(item_index, std::move(g)); + } template - mapped_type& operator[](K&& k) + size_t erase(const K& key) { - return try_emplace(std::forward(k)).first->second; - } + const size_t hash{hasher{}(key)}; + const detail::meta_byte mb{detail::meta_byte_from_hash(hash)}; -private: - iterator get_iterator_at(size_t item_index) - { - return {mbs.get() + item_index, slots.get() + item_index}; - } - - std::pair find_key_or_empty_slot(const key_type& key) - { - const size_t hash = std::hash{}(key); - const detail::meta_byte mb = detail::meta_byte_from_hash(hash); - - size_t group_index = detail::group_index_from_hash(hash, group_index_mask); + size_t group_index{detail::group_index_from_hash(hash, group_index_mask)}; while (true) { detail::meta_byte_group g{mbs.get() + group_index * group_size}; @@ -305,7 +336,140 @@ private: MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { const size_t item_index{group_index * group_size + match_index}; - if (slots[item_index].value.first == key) [[likely]] { + if (key_equal{}(slots[item_index].value.first, key)) [[likely]] { + erase_impl(item_index, std::move(g)); + return 1; + } + }); + + if (g.is_any_empty()) [[likely]] { + return 0; + } + + group_index = (group_index + 1) & group_index_mask; + } + } + + template + iterator find(const K& key) + { + const size_t hash{hasher{}(key)}; + const detail::meta_byte mb{detail::meta_byte_from_hash(hash)}; + + size_t group_index{detail::group_index_from_hash(hash, group_index_mask)}; + + while (true) { + detail::meta_byte_group g{mbs.get() + group_index * group_size}; + + MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { + const size_t item_index{group_index * group_size + match_index}; + + if (key_equal{}(slots[item_index].value.first, key)) [[likely]] { + return iterator_at(item_index); + } + }); + + if (g.is_any_empty()) [[likely]] { + return {}; + } + + group_index = (group_index + 1) & group_index_mask; + } + } + template + const_iterator find(const K& key) const + { + const size_t hash{hasher{}(key)}; + const detail::meta_byte mb{detail::meta_byte_from_hash(hash)}; + + size_t group_index{detail::group_index_from_hash(hash, group_index_mask)}; + + while (true) { + detail::meta_byte_group g{mbs.get() + group_index * group_size}; + + MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { + const size_t item_index{group_index * group_size + match_index}; + + if (key_equal{}(slots[item_index].value.first, key)) [[likely]] { + return const_iterator_at(item_index); + } + }); + + if (g.is_any_empty()) [[likely]] { + return {}; + } + + group_index = (group_index + 1) & group_index_mask; + } + } + template + bool contains(const K& key) const + { + return find(key) != end(); + } + template + size_t count(const K& key) const + { + return contains(key) ? 1 : 0; + } + + template + mapped_type& operator[](K&& k) + { + return try_emplace(std::forward(k)).first->second; + } + template + mapped_type& at(K&& k) + { + const auto iter{find(k)}; + if (iter == end()) { + throw std::out_of_range("hmap::at: key not found"); + } + return iter->second; + } + template + const mapped_type& at(K&& k) const + { + const auto iter{find(k)}; + if (iter == end()) { + throw std::out_of_range("hmap::at: key not found"); + } + return iter->second; + } + + void clear() + { + for (auto iter{begin()}; iter != end(); ++iter) { + iter->~value_type(); + } + + clear_metadata(); + } + +private: + iterator iterator_at(size_t item_index) + { + return {mbs.get() + item_index, slots.get() + item_index}; + } + const_iterator const_iterator_at(size_t item_index) + { + return {mbs.get() + item_index, slots.get() + item_index}; + } + + std::pair find_key_or_empty_slot(const key_type& key) + { + const size_t hash{hasher{}(key)}; + const detail::meta_byte mb{detail::meta_byte_from_hash(hash)}; + + size_t group_index{detail::group_index_from_hash(hash, group_index_mask)}; + + while (true) { + detail::meta_byte_group g{mbs.get() + group_index * group_size}; + + MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { + const size_t item_index{group_index * group_size + match_index}; + + if (key_equal{}(slots[item_index].value.first, key)) [[likely]] { return {item_index, true}; } }); @@ -324,7 +488,7 @@ private: grow_and_rehash(); } - size_t group_index = detail::group_index_from_hash(hash, group_index_mask); + size_t group_index{detail::group_index_from_hash(hash, group_index_mask)}; while (true) { detail::meta_byte_group g{mbs.get() + group_index * group_size}; @@ -335,6 +499,7 @@ private: if (mbs[item_index] == detail::meta_byte::empty) [[likely]] { --empty_slots; } + ++full_slots; mbs[item_index] = detail::meta_byte_from_hash(hash); @@ -345,25 +510,38 @@ private: } } + void erase_impl(size_t item_index, detail::meta_byte_group&& g) + { + slots[item_index].value->~value_type(); + + --full_slots; + if (g.is_any_empty()) { + mbs[item_index] = detail::meta_byte::empty; + ++empty_slots; + } else { + mbs[item_index] = detail::meta_byte::tombstone; + } + } + void grow_and_rehash() { - const size_t new_group_count = 2 * (group_index_mask + 1); + const size_t new_group_count{2 * (group_index_mask + 1)}; pow2_resize(new_group_count); } void pow2_resize(size_t new_group_count) { - auto iter = begin(); + auto iter{begin()}; - const auto old_mbs = std::move(mbs); - const auto old_slots = std::move(slots); + const auto old_mbs{std::move(mbs)}; + const auto old_slots{std::move(slots)}; initialize_members(new_group_count); for (; iter != iterator{}; ++iter) { - const size_t hash = std::hash{}(iter->first); - const size_t item_index = find_empty_slot_to_insert(hash); + const size_t hash{hasher{}(iter->first)}; + const size_t item_index{find_empty_slot_to_insert(hash)}; new (&slots[item_index].value) value_type(std::move(iter.slot_ptr->value)); iter.slot_ptr->value.~value_type(); @@ -375,16 +553,26 @@ private: // DEBUG_ASSERT(group_count != 0 && std::ispow2(group_count)); group_index_mask = group_count - 1; - empty_slots = group_count * group_size * 7 / 8; mbs = std::unique_ptr{new (std::align_val_t(group_size)) detail::meta_byte[group_count * group_size + 1]}; slots = std::unique_ptr{new slot_type[group_count * group_size]}; + clear_metadata(); + } + + void clear_metadata() + { + const size_t group_count{group_index_mask + 1}; + + empty_slots = group_count * group_size * 7 / 8; + full_slots = 0; + std::memset(mbs.get(), static_cast(detail::meta_byte::empty), group_count * group_size); mbs[group_count * group_size] = detail::meta_byte::end_sentinel; } size_t group_index_mask; size_t empty_slots; + size_t full_slots; std::unique_ptr mbs; std::unique_ptr slots; }; diff --git a/tests/hmap.cpp b/tests/hmap.cpp index 62989a1..5518574 100644 --- a/tests/hmap.cpp +++ b/tests/hmap.cpp @@ -15,8 +15,11 @@ TEST_CASE("mcl::hmap", "[hmap]") constexpr int count = 100000; + REQUIRE(double_map.empty()); + for (int i = 0; i < count; ++i) { double_map[i] = i * 2; + REQUIRE(double_map.size() == i + 1); } for (int i = 0; i < count; ++i) { @@ -36,4 +39,12 @@ TEST_CASE("mcl::hmap", "[hmap]") (void)k; REQUIRE(v == 1); } + + REQUIRE(!double_map.empty()); + double_map.clear(); + REQUIRE(double_map.empty()); + + for (auto [k, v] : double_map) { + REQUIRE(false); + } }