mcl: hmap: Add more member functions

This commit is contained in:
Merry 2022-05-01 14:24:21 +01:00
parent 4998335a5b
commit 91e3073ad3
2 changed files with 271 additions and 72 deletions

View File

@ -5,6 +5,8 @@
#pragma once #pragma once
#include <bit> #include <bit>
#include <cstddef>
#include <functional>
#include <utility> #include <utility>
#include "mcl/assert.hpp" #include "mcl/assert.hpp"
@ -14,12 +16,12 @@
#include "mcl/stdint.hpp" #include "mcl/stdint.hpp"
#if defined(MCL_ARCHITECTURE_ARM64) #if defined(MCL_ARCHITECTURE_ARM64)
# include "arm_neon.h" # include <arm_neon.h>
#endif #endif
namespace mcl { namespace mcl {
template<typename K, typename T> template<typename KeyType, typename MappedType, typename Hash, typename Pred>
class hmap; class hmap;
namespace detail { namespace detail {
@ -73,20 +75,25 @@ struct meta_byte_group {
return vmaxvq_u8(data) == 0xff; return vmaxvq_u8(data) == 0xff;
} }
bool is_all_empty_or_tombstone()
{
return vminvq_u8(vandq_u8(data, vdupq_n_u8(0x80))) == 0x80;
}
uint8x16_t data; uint8x16_t data;
}; };
# define MCL_HMAP_MATCH_META_BYTE_GROUP(MATCH, ...) \ # define MCL_HMAP_MATCH_META_BYTE_GROUP(MATCH, ...) \
{ \ { \
const uint64x2_t match_result = MATCH; \ const uint64x2_t match_result{MATCH}; \
\ \
for (u64 match_result_v{match_result[0]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \ for (u64 match_result_v{match_result[0]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \
const size_t match_index = std::countr_zero(match_result_v) / 8; \ const size_t match_index{static_cast<size_t>(std::countr_zero(match_result_v) / 8)}; \
__VA_ARGS__ \ __VA_ARGS__ \
} \ } \
\ \
for (u64 match_result_v{match_result[1]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \ for (u64 match_result_v{match_result[1]}; match_result_v != 0; match_result_v &= match_result_v - 1) { \
const size_t match_index = 8 + std::countr_zero(match_result_v) / 8; \ const size_t match_index{static_cast<size_t>(8 + std::countr_zero(match_result_v) / 8)}; \
__VA_ARGS__ \ __VA_ARGS__ \
} \ } \
} }
@ -103,12 +110,17 @@ union slot_union {
} // namespace detail } // namespace detail
template<typename K, typename T> template<bool IsConst, typename KeyType, typename MappedType, typename Hash, typename Pred>
class hmap_iterator { class hmap_iterator {
using base_value_type = std::pair<const KeyType, MappedType>;
using slot_type = detail::slot_union<base_value_type>;
public: public:
using key_type = KeyType;
using mapped_type = MappedType;
using iterator_category = std::forward_iterator_tag; using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using value_type = std::pair<K, T>; using value_type = std::conditional_t<IsConst, std::add_const_t<base_value_type>, base_value_type>;
using pointer = value_type*; using pointer = value_type*;
using const_pointer = const value_type*; using const_pointer = const value_type*;
using reference = value_type&; using reference = value_type&;
@ -156,9 +168,7 @@ public:
} }
private: private:
friend class hmap<K, T>; friend class hmap<KeyType, MappedType, Hash, Pred>;
using slot_type = detail::slot_union<value_type>;
hmap_iterator(detail::meta_byte* mb_ptr, slot_type* slot_ptr) hmap_iterator(detail::meta_byte* mb_ptr, slot_type* slot_ptr)
: mb_ptr{mb_ptr}, slot_ptr{slot_ptr} : mb_ptr{mb_ptr}, slot_ptr{slot_ptr}
@ -184,41 +194,57 @@ private:
} }
} }
detail::meta_byte* mb_ptr = nullptr; detail::meta_byte* mb_ptr{nullptr};
slot_type* slot_ptr = nullptr; slot_type* slot_ptr{nullptr};
}; };
template<typename KeyType, typename MappedType> template<typename KeyType, typename MappedType, typename Hash = std::hash<KeyType>, typename Pred = std::equal_to<KeyType>>
class hmap { class hmap {
public: public:
using key_type = KeyType; using key_type = KeyType;
using mapped_type = MappedType; using mapped_type = MappedType;
using value_type = std::pair<key_type, mapped_type>; using hasher = Hash;
using iterator = hmap_iterator<key_type, mapped_type>; using key_equal = Pred;
using value_type = std::pair<const key_type, mapped_type>;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using iterator = hmap_iterator<false, key_type, mapped_type, hasher, key_equal>;
using const_iterator = hmap_iterator<true, key_type, mapped_type, hasher, key_equal>;
private: private:
using slot_type = detail::slot_union<value_type>; using slot_type = detail::slot_union<value_type>;
static_assert(!std::is_reference_v<key_type>); static_assert(!std::is_reference_v<key_type>);
static_assert(!std::is_reference_v<mapped_type>); static_assert(!std::is_reference_v<mapped_type>);
static constexpr size_t group_size = 16; static constexpr size_t group_size{16};
public: public:
hmap() hmap()
{ {
initialize_members(1); initialize_members(1);
} }
hmap(const hmap&);
hmap(hmap&&);
hmap& operator=(const hmap&);
hmap& operator=(hmap&&);
~hmap() ~hmap()
{ {
for (auto iter = begin(); iter != end(); ++iter) { clear();
iter->~value_type();
}
} }
[[nodiscard]] bool empty() const noexcept { return full_slots == 0; }
// Number of live entries; the counter is bumped on insertion, dropped on
// erase and reset to zero when the metadata is cleared.
size_type size() const noexcept
{
    return full_slots;
}
// Upper bound on the element count: the largest value difference_type can hold.
size_type max_size() const noexcept
{
    constexpr difference_type largest_difference{std::numeric_limits<difference_type>::max()};
    return static_cast<size_type>(largest_difference);
}
iterator begin() iterator begin()
{ {
iterator result{get_iterator_at(0)}; iterator result{iterator_at(0)};
result.skip_empty_or_tombstone(); result.skip_empty_or_tombstone();
return result; return result;
} }
@ -226,34 +252,26 @@ public:
{ {
return {}; return {};
} }
const_iterator cbegin() const
iterator find(const key_type& key)
{ {
const size_t hash = std::hash<key_type>{}(key); const_iterator result{const_iterator_at(0)};
const detail::meta_byte mb = detail::meta_byte_from_hash(hash); result.skip_empty_or_tombstone();
return result;
size_t group_index = detail::group_index_from_hash(hash, group_index_mask); }
const_iterator cend() const
while (true) { {
detail::meta_byte_group g{mbs.get() + group_index * group_size}; return {};
}
MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { const_iterator begin() const
const size_t item_index{group_index * group_size + match_index}; {
return cbegin();
if (slots[item_index].value.first == key) [[likely]] { }
return get_iterator_at(item_index); const_iterator end() const
} {
}); return cend();
if (g.is_any_empty()) [[likely]] {
return {};
}
group_index = (group_index + 1) & group_index_mask;
}
} }
template<typename K, typename... Args> template<typename K = key_type, typename... Args>
std::pair<iterator, bool> try_emplace(K&& k, Args&&... args) std::pair<iterator, bool> try_emplace(K&& k, Args&&... args)
{ {
auto [item_index, item_found] = find_key_or_empty_slot(k); auto [item_index, item_found] = find_key_or_empty_slot(k);
@ -263,10 +281,10 @@ public:
std::forward_as_tuple(std::forward<K>(k)), std::forward_as_tuple(std::forward<K>(k)),
std::forward_as_tuple(std::forward<Args>(args)...)); std::forward_as_tuple(std::forward<Args>(args)...));
} }
return {get_iterator_at(item_index), !item_found}; return {iterator_at(item_index), !item_found};
} }
template<typename K, typename V> template<typename K = key_type, typename V = mapped_type>
std::pair<iterator, bool> insert_or_assign(K&& k, V&& v) std::pair<iterator, bool> insert_or_assign(K&& k, V&& v)
{ {
auto [item_index, item_found] = find_key_or_empty_slot(k); auto [item_index, item_found] = find_key_or_empty_slot(k);
@ -277,27 +295,40 @@ public:
std::forward<K>(k), std::forward<K>(k),
std::forward<V>(v)); std::forward<V>(v));
} }
return {get_iterator_at(item_index), !item_found}; return {iterator_at(item_index), !item_found};
} }
// Removes the element referred to by `position`.
// Passing cend() is a no-op.
void erase(const_iterator position)
{
    if (position == cend()) {
        return;
    }

    // The element's index is its metadata byte's offset from the start of the
    // metadata array; the owning group follows from the fixed group size.
    const size_t item_index{static_cast<size_t>(std::distance(mbs.get(), position.mb_ptr))};
    const size_t group_index{item_index / group_size};
    // Must not be const: erase_impl takes the group by non-const rvalue
    // reference, and std::move on a const object yields a const rvalue that
    // cannot bind to that parameter.
    detail::meta_byte_group g{mbs.get() + group_index * group_size};

    erase_impl(item_index, std::move(g));
}
// Removes the element referred to by `position`.
// Passing end() is a no-op.
void erase(iterator position)
{
    if (position == end()) {
        return;
    }

    // The element's index is its metadata byte's offset from the start of the
    // metadata array; the owning group follows from the fixed group size.
    const size_t item_index{static_cast<size_t>(std::distance(mbs.get(), position.mb_ptr))};
    const size_t group_index{item_index / group_size};
    // Must not be const: erase_impl takes the group by non-const rvalue
    // reference, and std::move on a const object yields a const rvalue that
    // cannot bind to that parameter.
    detail::meta_byte_group g{mbs.get() + group_index * group_size};

    erase_impl(item_index, std::move(g));
}
template<typename K = key_type> template<typename K = key_type>
mapped_type& operator[](K&& k) size_t erase(const K& key)
{ {
return try_emplace(std::forward<K>(k)).first->second; const size_t hash{hasher{}(key)};
} const detail::meta_byte mb{detail::meta_byte_from_hash(hash)};
private: size_t group_index{detail::group_index_from_hash(hash, group_index_mask)};
iterator get_iterator_at(size_t item_index)
{
return {mbs.get() + item_index, slots.get() + item_index};
}
std::pair<size_t, bool> find_key_or_empty_slot(const key_type& key)
{
const size_t hash = std::hash<key_type>{}(key);
const detail::meta_byte mb = detail::meta_byte_from_hash(hash);
size_t group_index = detail::group_index_from_hash(hash, group_index_mask);
while (true) { while (true) {
detail::meta_byte_group g{mbs.get() + group_index * group_size}; detail::meta_byte_group g{mbs.get() + group_index * group_size};
@ -305,7 +336,140 @@ private:
MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), { MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), {
const size_t item_index{group_index * group_size + match_index}; const size_t item_index{group_index * group_size + match_index};
if (slots[item_index].value.first == key) [[likely]] { if (key_equal{}(slots[item_index].value.first, key)) [[likely]] {
erase_impl(item_index, std::move(g));
return 1;
}
});
if (g.is_any_empty()) [[likely]] {
return 0;
}
group_index = (group_index + 1) & group_index_mask;
}
}
// Looks up `key` and returns an iterator to its entry, or a default-constructed
// iterator when the key is absent.
//
// Probing starts at the group selected by the key's hash and advances one
// group at a time (wrapping via group_index_mask). Within a group only slots
// whose metadata byte matches the hash-derived byte are compared with
// key_equal. The search stops at the first group that contains an empty slot.
template<typename K = key_type>
iterator find(const K& key)
{
    const size_t key_hash{hasher{}(key)};
    const detail::meta_byte wanted{detail::meta_byte_from_hash(key_hash)};

    size_t group_index{detail::group_index_from_hash(key_hash, group_index_mask)};

    for (;;) {
        detail::meta_byte_group group{mbs.get() + group_index * group_size};

        MCL_HMAP_MATCH_META_BYTE_GROUP(group.match(wanted), {
            const size_t candidate{group_index * group_size + match_index};

            if (key_equal{}(slots[candidate].value.first, key)) [[likely]] {
                return iterator_at(candidate);
            }
        });

        if (group.is_any_empty()) [[likely]] {
            return {};
        }

        group_index = (group_index + 1) & group_index_mask;
    }
}
// Const lookup of `key`: returns a const_iterator to its entry, or a
// default-constructed const_iterator when the key is absent.
//
// Probing starts at the group selected by the key's hash and advances one
// group at a time (wrapping via group_index_mask). Within a group only slots
// whose metadata byte matches the hash-derived byte are compared with
// key_equal. The search stops at the first group that contains an empty slot.
template<typename K = key_type>
const_iterator find(const K& key) const
{
    const size_t key_hash{hasher{}(key)};
    const detail::meta_byte wanted{detail::meta_byte_from_hash(key_hash)};

    size_t group_index{detail::group_index_from_hash(key_hash, group_index_mask)};

    for (;;) {
        detail::meta_byte_group group{mbs.get() + group_index * group_size};

        MCL_HMAP_MATCH_META_BYTE_GROUP(group.match(wanted), {
            const size_t candidate{group_index * group_size + match_index};

            if (key_equal{}(slots[candidate].value.first, key)) [[likely]] {
                return const_iterator_at(candidate);
            }
        });

        if (group.is_any_empty()) [[likely]] {
            return {};
        }

        group_index = (group_index + 1) & group_index_mask;
    }
}
template<typename K = key_type>
bool contains(const K& key) const
{
return find(key) != end();
}
// Number of entries stored under `key`: 1 if contains(key), otherwise 0.
template<typename K = key_type>
size_t count(const K& key) const
{
    if (contains(key)) {
        return 1;
    }
    return 0;
}
// Returns a reference to the value mapped to `k`. The lookup goes through
// try_emplace with no value arguments, so a missing key gets an entry whose
// mapped value is constructed from an empty argument list.
template<typename K = key_type>
mapped_type& operator[](K&& k)
{
    auto emplace_result = try_emplace(std::forward<K>(k));
    return emplace_result.first->second;
}
// Checked access: returns a reference to the value mapped to `k`.
// Throws std::out_of_range when no entry for `k` exists.
// NOTE(review): std::out_of_range is declared in <stdexcept>; confirm the
// header is included by this file.
template<typename K = key_type>
mapped_type& at(K&& k)
{
    const auto position{find(k)};

    if (position == end()) {
        throw std::out_of_range("hmap::at: key not found");
    }

    return position->second;
}
// Checked read-only access: returns a const reference to the value mapped to `k`.
// Throws std::out_of_range when no entry for `k` exists.
// NOTE(review): std::out_of_range is declared in <stdexcept>; confirm the
// header is included by this file.
template<typename K = key_type>
const mapped_type& at(K&& k) const
{
    const auto position{find(k)};

    if (position == end()) {
        throw std::out_of_range("hmap::at: key not found");
    }

    return position->second;
}
void clear()
{
for (auto iter{begin()}; iter != end(); ++iter) {
iter->~value_type();
}
clear_metadata();
}
private:
// Builds an iterator referring to slot `item_index`, pairing the slot's
// metadata byte with its storage cell.
iterator iterator_at(size_t item_index)
{
    detail::meta_byte* const mb_ptr{mbs.get() + item_index};
    slot_type* const slot_ptr{slots.get() + item_index};
    return {mb_ptr, slot_ptr};
}
// Builds a const_iterator referring to slot `item_index`, pairing the slot's
// metadata byte with its storage cell.
//
// const-qualified because it is called from const member functions
// (cbegin() const, find() const). unique_ptr::get() is itself const and yields
// a pointer to non-const elements, so the body needs no casts.
const_iterator const_iterator_at(size_t item_index) const
{
    return {mbs.get() + item_index, slots.get() + item_index};
}
std::pair<size_t, bool> find_key_or_empty_slot(const key_type& key)
{
const size_t hash{hasher{}(key)};
const detail::meta_byte mb{detail::meta_byte_from_hash(hash)};
size_t group_index{detail::group_index_from_hash(hash, group_index_mask)};
while (true) {
detail::meta_byte_group g{mbs.get() + group_index * group_size};
MCL_HMAP_MATCH_META_BYTE_GROUP(g.match(mb), {
const size_t item_index{group_index * group_size + match_index};
if (key_equal{}(slots[item_index].value.first, key)) [[likely]] {
return {item_index, true}; return {item_index, true};
} }
}); });
@ -324,7 +488,7 @@ private:
grow_and_rehash(); grow_and_rehash();
} }
size_t group_index = detail::group_index_from_hash(hash, group_index_mask); size_t group_index{detail::group_index_from_hash(hash, group_index_mask)};
while (true) { while (true) {
detail::meta_byte_group g{mbs.get() + group_index * group_size}; detail::meta_byte_group g{mbs.get() + group_index * group_size};
@ -335,6 +499,7 @@ private:
if (mbs[item_index] == detail::meta_byte::empty) [[likely]] { if (mbs[item_index] == detail::meta_byte::empty) [[likely]] {
--empty_slots; --empty_slots;
} }
++full_slots;
mbs[item_index] = detail::meta_byte_from_hash(hash); mbs[item_index] = detail::meta_byte_from_hash(hash);
@ -345,25 +510,38 @@ private:
} }
} }
// Destroys the entry at `item_index` and updates the bookkeeping.
//
// `g` is the metadata group containing `item_index`, as loaded by the caller
// before the erase.
// If that group still has an empty slot, the erased slot is marked empty and
// the empty-slot budget is refunded; otherwise it is marked as a tombstone.
// NOTE(review): presumably the empty marking is safe because lookups stop at
// the first group containing an empty slot (see find) — confirm.
void erase_impl(size_t item_index, detail::meta_byte_group&& g)
{
    // `value` is an object inside the slot union, not a pointer, so the
    // destructor is invoked with `.` (as done everywhere else in this class).
    slots[item_index].value.~value_type();

    --full_slots;

    if (g.is_any_empty()) {
        mbs[item_index] = detail::meta_byte::empty;
        ++empty_slots;
    } else {
        mbs[item_index] = detail::meta_byte::tombstone;
    }
}
void grow_and_rehash() void grow_and_rehash()
{ {
const size_t new_group_count = 2 * (group_index_mask + 1); const size_t new_group_count{2 * (group_index_mask + 1)};
pow2_resize(new_group_count); pow2_resize(new_group_count);
} }
void pow2_resize(size_t new_group_count) void pow2_resize(size_t new_group_count)
{ {
auto iter = begin(); auto iter{begin()};
const auto old_mbs = std::move(mbs); const auto old_mbs{std::move(mbs)};
const auto old_slots = std::move(slots); const auto old_slots{std::move(slots)};
initialize_members(new_group_count); initialize_members(new_group_count);
for (; iter != iterator{}; ++iter) { for (; iter != iterator{}; ++iter) {
const size_t hash = std::hash<key_type>{}(iter->first); const size_t hash{hasher{}(iter->first)};
const size_t item_index = find_empty_slot_to_insert(hash); const size_t item_index{find_empty_slot_to_insert(hash)};
new (&slots[item_index].value) value_type(std::move(iter.slot_ptr->value)); new (&slots[item_index].value) value_type(std::move(iter.slot_ptr->value));
iter.slot_ptr->value.~value_type(); iter.slot_ptr->value.~value_type();
@ -375,16 +553,26 @@ private:
// DEBUG_ASSERT(group_count != 0 && std::ispow2(group_count)); // DEBUG_ASSERT(group_count != 0 && std::ispow2(group_count));
group_index_mask = group_count - 1; group_index_mask = group_count - 1;
empty_slots = group_count * group_size * 7 / 8;
mbs = std::unique_ptr<detail::meta_byte[]>{new (std::align_val_t(group_size)) detail::meta_byte[group_count * group_size + 1]}; mbs = std::unique_ptr<detail::meta_byte[]>{new (std::align_val_t(group_size)) detail::meta_byte[group_count * group_size + 1]};
slots = std::unique_ptr<slot_type[]>{new slot_type[group_count * group_size]}; slots = std::unique_ptr<slot_type[]>{new slot_type[group_count * group_size]};
clear_metadata();
}
void clear_metadata()
{
const size_t group_count{group_index_mask + 1};
empty_slots = group_count * group_size * 7 / 8;
full_slots = 0;
std::memset(mbs.get(), static_cast<int>(detail::meta_byte::empty), group_count * group_size); std::memset(mbs.get(), static_cast<int>(detail::meta_byte::empty), group_count * group_size);
mbs[group_count * group_size] = detail::meta_byte::end_sentinel; mbs[group_count * group_size] = detail::meta_byte::end_sentinel;
} }
size_t group_index_mask; size_t group_index_mask;
size_t empty_slots; size_t empty_slots;
size_t full_slots;
std::unique_ptr<detail::meta_byte[]> mbs; std::unique_ptr<detail::meta_byte[]> mbs;
std::unique_ptr<slot_type[]> slots; std::unique_ptr<slot_type[]> slots;
}; };

View File

@ -15,8 +15,11 @@ TEST_CASE("mcl::hmap", "[hmap]")
constexpr int count = 100000; constexpr int count = 100000;
REQUIRE(double_map.empty());
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
double_map[i] = i * 2; double_map[i] = i * 2;
REQUIRE(double_map.size() == i + 1);
} }
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
@ -36,4 +39,12 @@ TEST_CASE("mcl::hmap", "[hmap]")
(void)k; (void)k;
REQUIRE(v == 1); REQUIRE(v == 1);
} }
REQUIRE(!double_map.empty());
double_map.clear();
REQUIRE(double_map.empty());
for (auto [k, v] : double_map) {
REQUIRE(false);
}
} }