[C++ STL]: 如何hack一个最简单的哈希表?

故事从阿里的电话面连环问开始:
Q: 哈希表知道吧, 说一下基本的原理?
A: 您问的是普遍的哈希表公共的一些性质, 还是特指C++的unordered_map/set它们?
Q: 都可以... 处理哈希碰撞有哪些方式? 线性探测说一下, 可能有什么难点? 您提到了"开链", 请问这个链是什么链? 如果链太长了, 你会怎么设计 (说了像Java那样链条重塑成红黑树, 或者依然是线性的, 但跳表那样)?

问题多少有点热身的性质, 但挺考验对哈希表设计细节的理解. 比如笔者的确不知道"这个链是什么链", 但battle一会确认单链就足够了. 本文的目的是挖一些我们C++用到的哈希表, 更细节一点 (也没被问到) 的性质; 最后跳出C++和STL的语境, 看看其他语境下的哈希表, 其设计各有哪些独到之处.

key到哈希值的映射, 可能发生"碰撞"

哈希值到bucket映射, 也可能发生"碰撞"

哈希表第一个离不开的是哈希函数: 它提供了一种给定key, 生成映射到"定长签名"的机制. 简单来说, 一个哈希必须具备一系列性质, 一个优秀的哈希函数最好具备一些性质:
必须具备:

  • 确定性: 显然给定所有输入的参数, key计算出的哈希值必须是确定的.
  • 允许输入key长度可变.
  • 生成哈希值 (即上述的"签名") 定长.

最好具备:

  • 雪崩效应 (avalanche effect): 输入key的微小改变, 足以引起哈希值的相当大的改变. 如我们用sha256sum计算字符串"std::unordered_map<double, int>"和"ste::unordered_map<double, int>"的sha256算法哈希值, 两者只在一个位上发生了翻转.
echo -n "std::unordered_map<double, int>" | sha256sum
   >> 5f330a343b7ce65d79ba0829472e1eb0310ab780363bc1241c9e6ba3df75cca8  -
echo -n "ste::unordered_map<double, int>" | sha256sum
   >> 5b7914bc0a6b9dc996e99af1ff14795dcf7ed2c3cbaf4c930b93badeef7484e8  -
  • 均匀: 将输入空间的keys, 均匀地散布到哈希值空间的每一个取值.

想用哈希值进行key的look-up操作, 还有一些实际的考量: 定长的哈希值, 长度少则32, 64位, 256位也不是稀罕事, 所以我们肯定不能给每个可能的哈希值都分配一个我们称为bucket的空间, 而且这样的空间还得是连续的. 所以我们得允许多个不同的哈希值映射到同一个bucket. OK, 现在能具体介绍C++ STL语境下, "哈希值的计算", "bucket的定位", "元素的定位" 这三步是怎么走的了.

std::hash<_Value> 的特化

  /// See: hashtable.h
  /// ----------------
  template<typename _Key, typename _Value,
       typename _Alloc, typename _ExtractKey, typename _Equal,
       typename _H1, typename _H2, typename _Hash, typename _RehashPolicy,
       typename _Traits>
    auto
    _Hashtable<_Key, _Value, _Alloc, _ExtractKey, _Equal,
           _H1, _H2, _Hash, _RehashPolicy, _Traits>::
    find(const key_type& __k) const
    -> const_iterator
    {
      __hash_code __code = this->_M_hash_code(__k);
      std::size_t __n = _M_bucket_index(__k, __code);
      __node_type* __p = _M_find_node(__n, __k, __code);
      return __p ? const_iterator(__p) : end();
    }

我们调用find()方法的大致流程如上.

  • _M_hash_code: 计算哈希值.
  • _M_bucket_index: 找到哈希值落在的bucket index.
  • _M_find_node: bucket内找到想要的K/V.

其中第 3 行模板参数, _H1, _H2, _Hash, _RehashPolicy值得注意, 且_M_hash_code函数执行哈希值计算的逻辑. 从key计算哈希值使用std::hash<_Value>作为_H1模板参数; 从哈希值计算bucket index使用std::__detail::_Mod_range_hashing作为_H2模板参数; 使用std::__detail::_Prime_rehash_policy作为_RehashPolicy模板参数去re-hash (划重点).

  /// See: unordered_set.h
  /// --------------------
  /// Base types for unordered_set.
  template<bool _Cache>
    using __uset_traits = __detail::_Hashtable_traits<_Cache, true, true>;

  template<typename _Value,
       typename _Hash = hash<_Value>,
       typename _Pred = std::equal_to<_Value>,
       typename _Alloc = std::allocator<_Value>,
       typename _Tr = __uset_traits<__cache_default<_Value, _Hash>::value>>
    using __uset_hashtable = _Hashtable<_Value, _Value, _Alloc,
                    __detail::_Identity, _Pred, _Hash,
                    __detail::_Mod_range_hashing,
                    __detail::_Default_ranged_hash,
                    __detail::_Prime_rehash_policy, _Tr>;

哈希值的类型始终为64位的size_t, 因此是定长的. 当_value类型是比较trivial的整型数, 直接用传入的数值就能作为哈希值, 最多再执行显式转型.

  /// See: functional_hash.h
  /// ----------------------
  /// Explicit specialization for bool.
  _Cxx_hashtable_define_trivial_hash(bool)

  /// Explicit specialization for char.
  _Cxx_hashtable_define_trivial_hash(char)

  /// Explicit specialization for signed char.
  _Cxx_hashtable_define_trivial_hash(signed char)

  /// Explicit specialization for int.
  _Cxx_hashtable_define_trivial_hash(int)

  /// ...

  #define _Cxx_hashtable_define_trivial_hash(_Tp)   \
  template<>                        \
    struct hash<_Tp> : public __hash_base<size_t, _Tp>  \
    {                                                   \
      size_t                                            \
      operator()(_Tp __val) const noexcept              \
      { return static_cast<size_t>(__val); }            \
    };

  /// Explicit specialization for bool.
  _Cxx_hashtable_define_trivial_hash(bool)

但对于浮点数, 乃至std::string这样连大小都不确定, 不trivial的key, std::hash<_Value>又会怎么计算哈希值?

Murmur哈希

float为例, 特化后最终调用的_Hash_bytes函数, 在libstdc++对64位的size_t实现中如下:

  // Implementation of Murmur hash for 64-bit size_t.
  size_t _Hash_bytes(const void* ptr, size_t len, size_t seed) {
    static const size_t mul = (((size_t) 0xc6a4a793UL) << 32UL) + (size_t) 0x5bd1e995UL;
    const char* const buf = static_cast<const char*>(ptr);

    // Remove the bytes not divisible by the sizeof(size_t).  This
    // allows the main loop to process the data as 64-bit integers.
    const size_t len_aligned = len & ~(size_t)0x7;
    const char* const end = buf + len_aligned;
    size_t hash = seed ^ (len * mul);
    for (const char* p = buf; p != end; p += 8) {
      const size_t data = shift_mix(unaligned_load(p) * mul) * mul;
      hash ^= data;
      hash *= mul;
    }
    if ((len & 0x7) != 0) {
      const size_t data = load_bytes(end, len & 0x7);
      hash ^= data;
      hash *= mul;
    }
    hash = shift_mix(hash) * mul;
    hash = shift_mix(hash);
    return hash;
  }

  inline std::size_t unaligned_load(const char* p) {
    std::size_t result;
    __builtin_memcpy(&result, p, sizeof(result));
    return result;
  }

  inline std::size_t load_bytes(const char* p, int n) {
    std::size_t result = 0;
    --n;
    do {
      result = (result << 8) + static_cast<unsigned char>(p[n]);
    } while (--n >= 0);
    return result;
  }

  inline std::size_t shift_mix(std::size_t v) {
    return v ^ (v >> 47);
  }

流程如下:

  • size_t hash = seed ^ (len * mul);根据随机数, key的字节大小, 乘积因子计算初始哈希值hash.
  • 除最后一轮, 每轮读入 8 字节解释为size_t类型, 经过乘 (Mu) 和旋转 (R), 逐渐混入hash.
  • 最后一轮如果不足数, 读取剩余字节, 直接混入hash.
  • hash返回前执行后处理.

其中存在部分写死的常量: 如右移47位, 乘积因子mul, 以及特化时传入的种子__seed=0xc70f6907UL. 同时还有一个隐含结论: float类型的变量和std::string类型的变量, 如果字节布局完全一致, 则使用std::unordered_set默认哈希函数计算, 两者的哈希值是完全一致的! 我们可以验证这一点. 一并验证上述的Murmur3函数就是默认使用的哈希函数:

#include <unordered_set>
#include <cstring>
#include <iostream>
#include <functional>

namespace test {
    size_t __seed = static_cast<size_t>(0xc70f6907UL);

    inline std::size_t unaligned_load(const char* p) {
        std::size_t result;
        __builtin_memcpy(&result, p, sizeof(result));
        return result;
    }

    inline std::size_t load_bytes(const char* p, int n) {
        std::size_t result = 0;
        --n;
        do {
            result = (result << 8) + static_cast<unsigned char>(p[n]);
        } while (--n >= 0);
        return result;
    }

    inline std::size_t shift_mix(std::size_t v) {
        return v ^ (v >> 47);
    }

    size_t _Hash_bytes(const void* ptr, size_t len, size_t seed) {
        static const size_t mul = (static_cast<size_t>(0xc6a4a793UL) << 32UL) 
            + static_cast<size_t>(0x5bd1e995UL);
        // <=> static const size_t mul = 0xc6a4a7935bd1e995UL;
        const char* const buf = static_cast<const char*>(ptr);

        // Remove the bytes not divisible by the sizeof(size_t).  This
        // allows the main loop to process the data as 64-bit integers.
        const size_t len_aligned = len & ~static_cast<size_t>(0x7);
        const char* const end = buf + len_aligned;
        size_t hash = seed ^ (len * mul);
        for (const char* p = buf; p != end; p += 8) {
            // load and interpret 8-byte as a `size_t` type value.
            const size_t data = shift_mix(unaligned_load(p) * mul) * mul;
            hash ^= data;
            hash *= mul;
        }
        if ((len & 0x7) != 0) {
            const size_t data = load_bytes(end, len & 0x7);
            hash ^= data;
            hash *= mul;
        }
        hash = shift_mix(hash) * mul;
        hash = shift_mix(hash);
        return hash;
    }
}  // namespace test



int main() {
    float val_float = 100.1;
    void *ptr = static_cast<void*>(&val_float);

    char buff[10];
    buff[sizeof(float)] = '\0';
    memcpy(&buff, &val_float, sizeof(float));
    std::string str = buff;

    std::unordered_set<float> float_hash_tab;
    std::unordered_set<std::string> string_hash_tab;

    std::cout << "hash as `float`:        " << float_hash_tab.hash_function()(val_float) << std::endl;
    std::cout << "hash as `std::string`:  " << string_hash_tab.hash_function()(str) << std::endl;
    std::cout << "using std::hash<float>: " << std::hash<float>()(val_float) << std::endl;
    std::cout << "using murmur directly:  " << test::_Hash_bytes(ptr, 4, test::__seed) << std::endl;
    // 都是 `11097413356452607058`.
    return 0;
}

如何找到对应的bucket和K/V?

std::__detail::_Mod_range_hashing用了取余的简单方法, 完成从哈希值到bucket index的映射. 其中_M_bucket_count维护了bucket计数. 在既定的bucket内部, 遍历直到第一个元素匹配key.

  /// See: hashtable_policy.h
  /// -----------------------  
  /// Default range hashing function: use division to fold a large number
  /// into the range [0, N).
  struct _Mod_range_hashing
  {
    typedef std::size_t first_argument_type;
    typedef std::size_t second_argument_type;
    typedef std::size_t result_type;

    result_type
    operator()(first_argument_type __num,
           second_argument_type __den) const noexcept
    { return __num % __den; }
  };

  std::size_t
  _M_bucket_index(const _Key& __k, __hash_code, std::size_t __n) const
  { return _M_ranged_hash()(__k, __n); }

  /// See: hashtable.h
  /// ----------------
  size_type
  _M_bucket_index(const key_type& __k, __hash_code __c) const
  { return __hash_code_base::_M_bucket_index(__k, __c, _M_bucket_count); }

  __node_type*
  _M_find_node(size_type __bkt, const key_type& __key, __hash_code __c) const {
    __node_base* __before_n = _M_find_before_node(__bkt, __key, __c);
    if (__before_n)
      return static_cast<__node_type*>(__before_n->_M_nxt);
    return nullptr;
  }

扩容

哈希表维护当前bucket和element分别的数目, 每次插入时会据此计算装填因子. _M_need_rehash判定是否需要re-hash:
假定插入成功后, 新的element数目超过了触发下次扩容的阈值, 当前bucket数目已不足以维护装填因子小于阈值, 则以一定的扩容因子对bucket扩容. 扩容因子_M_growth_factor为2; 默认装填因子_M_max_load_factor为1.

   *  Each _Hashtable data structure has:
   *
   *  - _Bucket[]       _M_buckets
   *  - _Hash_node_base _M_before_begin
   *  - size_type       _M_bucket_count
   *  - size_type       _M_element_count
   
  // Insert node, in bucket bkt if no rehash (assumes no element with its key
  // already present). Take ownership of the node, deallocate it on exception.
  template<typename _Key, typename _Value,
       typename _Alloc, typename _ExtractKey, typename _Equal,
       typename _H1, typename _H2, typename _Hash, typename _RehashPolicy,
       typename _Traits>
    auto
    _Hashtable<_Key, _Value, _Alloc, _ExtractKey, _Equal,
           _H1, _H2, _Hash, _RehashPolicy, _Traits>::
    _M_insert_multi_node(__node_type* __hint, __hash_code __code,
             __node_type* __node)
    -> iterator
    {
      const __rehash_state& __saved_state = _M_rehash_policy._M_state();
      std::pair<bool, std::size_t> __do_rehash
    = _M_rehash_policy._M_need_rehash(_M_bucket_count, _M_element_count, 1);
        
  // See hashtable_policy.h [tr1]
  // ----------------------
  // Default value for rehash policy.  Bucket size is (usually) the
  // smallest prime that keeps the load factor small enough.
  struct _Prime_rehash_policy {
    _Prime_rehash_policy(float __z = 1.0)
    : _M_max_load_factor(__z), _M_growth_factor(2.f), _M_next_resize(0) { }
  }
  extern const unsigned long __prime_list[];
        
  // Return a prime no smaller than n.
  inline std::size_t
  _Prime_rehash_policy::
  _M_next_bkt(std::size_t __n) const
  {
    // Don't include the last prime in the search, so that anything
    // higher than the second-to-last prime returns a past-the-end
    // iterator that can be dereferenced to get the last prime.
    const unsigned long* __p
      = std::lower_bound(__prime_list, __prime_list + _S_n_primes - 1, __n);
    _M_next_resize = 
      static_cast<std::size_t>(__builtin_ceil(*__p * _M_max_load_factor));
    return *__p;
  }

  // Finds the smallest prime p such that alpha p > __n_elt + __n_ins.
  // If p > __n_bkt, return make_pair(true, p); otherwise return
  // make_pair(false, 0).  In principle this isn't very different from 
  // _M_bkt_for_elements.

  // The only tricky part is that we're caching the element count at
  // which we need to rehash, so we don't have to do a floating-point
  // multiply for every insertion.

  inline std::pair<bool, std::size_t>
  _Prime_rehash_policy::
  _M_need_rehash(std::size_t __n_bkt, std::size_t __n_elt,
         std::size_t __n_ins) const
  {
    if (__n_elt + __n_ins > _M_next_resize)
      {
    float __min_bkts = ((float(__n_ins) + float(__n_elt))
                / _M_max_load_factor);
    if (__min_bkts > __n_bkt)
      {
        __min_bkts = std::max(__min_bkts, _M_growth_factor * __n_bkt);
        return std::make_pair(true,
                  _M_next_bkt(__builtin_ceil(__min_bkts)));
      }
    else 
      {
        _M_next_resize = static_cast<std::size_t>
          (__builtin_ceil(__n_bkt * _M_max_load_factor));
        return std::make_pair(false, 0);
      }
      }
    else
      return std::make_pair(false, 0);
  }

有个存疑的问题: 怎么知道素数表__prime_list[]在哪个文件定义? 据说在hashtable-aux.cc源文件, 但笔者的本地并没有. 扩容时, bucket index的转移是"无记忆"的, 笔者想表达的是什么意思呢?

假设扩容前有 \alpha 个bucket B_1[0]\sim B_{1}[\alpha-1], 扩容后有 \beta 个bucketB_2[0]\sim B_{2}[\beta-1]\beta\approx2\alpha. 假设key kB_{1}下最终映射的bucket index为 i, 0\leq i<\alpha; 那么re-hash后, 在B_{2}下新映射的bucket index j 只和既定的H_1, H_2函数有关, 和历史状态 i 无关.

开始Hack

现在我们能构造一种情况, 让大量的K/V都落在同一个bucket里了.

#include <unordered_map>
#include <iostream>

const int N = 1e5;

extern const unsigned long __prime_list[] = // 256 + 1 or 256 + 48 + 1
  {
    2ul, 3ul, 5ul, 7ul, 11ul, 13ul, 17ul, 19ul, 23ul, 29ul, 31ul,
    37ul, 41ul, 43ul, 47ul, 53ul, 59ul, 61ul, 67ul, 71ul, 73ul, 79ul,
    83ul, 89ul, 97ul, 103ul, 109ul, 113ul, 127ul, 137ul, 139ul, 149ul,
    157ul, 167ul, 179ul, 193ul, 199ul, 211ul, 227ul, 241ul, 257ul,
    277ul, 293ul, 313ul, 337ul, 359ul, 383ul, 409ul, 439ul, 467ul,
    503ul, 541ul, 577ul, 619ul, 661ul, 709ul, 761ul, 823ul, 887ul,
    953ul, 1031ul, 1109ul, 1193ul, 1289ul, 1381ul, 1493ul, 1613ul,
    1741ul, 1879ul, 2029ul, 2179ul, 2357ul, 2549ul, 2753ul, 2971ul,
    3209ul, 3469ul, 3739ul, 4027ul, 4349ul, 4703ul, 5087ul, 5503ul,
    5953ul, 6427ul, 6949ul, 7517ul, 8123ul, 8783ul, 9497ul, 10273ul,
    11113ul, 12011ul, 12983ul, 14033ul, 15173ul, 16411ul, 17749ul,
    19183ul, 20753ul, 22447ul, 24281ul, 26267ul, 28411ul, 30727ul,
    33223ul, 35933ul, 38873ul, 42043ul, 45481ul, 49201ul, 53201ul,
    57557ul, 62233ul, 67307ul, 72817ul, 78779ul, 85229ul, 92203ul,
    99733ul, 107897ul, 116731ul, 126271ul, 136607ul, 147793ul,
    159871ul, 172933ul, 187091ul, 202409ul, 218971ul, 236897ul,
    256279ul, 277261ul, 299951ul, 324503ul, 351061ul, 379787ul,
    410857ul, 444487ul, 480881ul, 520241ul, 562841ul, 608903ul,
    658753ul, 712697ul, 771049ul, 834181ul, 902483ul, 976369ul,
    1056323ul, 1142821ul, 1236397ul, 1337629ul, 1447153ul, 1565659ul,
    1693859ul, 1832561ul, 1982627ul, 2144977ul, 2320627ul, 2510653ul,
    2716249ul, 2938679ul, 3179303ul, 3439651ul, 3721303ul, 4026031ul,
    4355707ul, 4712381ul, 5098259ul, 5515729ul, 5967347ul, 6456007ul,
    6984629ul, 7556579ul, 8175383ul, 8844859ul, 9569143ul, 10352717ul,
    11200489ul, 12117689ul, 13109983ul, 14183539ul, 15345007ul,
    16601593ul, 17961079ul, 19431899ul, 21023161ul, 22744717ul,
    24607243ul, 26622317ul, 28802401ul, 31160981ul, 33712729ul,
    36473443ul, 39460231ul, 42691603ul, 46187573ul, 49969847ul,
    54061849ul, 58488943ul, 63278561ul, 68460391ul, 74066549ul,
    80131819ul, 86693767ul, 93793069ul, 101473717ul, 109783337ul,
    118773397ul, 128499677ul, 139022417ul, 150406843ul, 162723577ul,
    176048909ul, 190465427ul, 206062531ul, 222936881ul, 241193053ul,
    260944219ul, 282312799ul, 305431229ul, 330442829ul, 357502601ul,
    386778277ul, 418451333ul, 452718089ul, 489790921ul, 529899637ul,
    573292817ul, 620239453ul, 671030513ul, 725980837ul, 785430967ul,
    849749479ul, 919334987ul, 994618837ul, 1076067617ul, 1164186217ul,
    1259520799ul, 1362662261ul, 1474249943ul, 1594975441ul, 1725587117ul,
    1866894511ul, 2019773507ul, 2185171673ul, 2364114217ul, 2557710269ul,
    2767159799ul, 2993761039ul, 3238918481ul, 3504151727ul, 3791104843ul,
    4101556399ul, 4294967291ul,
  };

void insert_numbers(u_int64_t x) {
    clock_t begin = clock();
    std::unordered_map<u_int64_t, int> numbers;

    for (int i = 1; i <= N; i++) {
        numbers[i * x] = i;
    }
    
    u_int64_t sum = 0;
    for (auto &entry : numbers) {
        sum += (entry.first / x) * entry.second;
    }
    
    printf("********\n");
    printf("x = %lld: %.3lf seconds, sum = %lld\n", x, 
           static_cast<double>(clock() - begin) / CLOCKS_PER_SEC, sum);
    printf("bucket count = %lld\n", numbers.bucket_count());
    printf("********\n");
    return;
}


int main() {
    insert_numbers(5087);
    insert_numbers(10273);
    insert_numbers(20753);  // !
    insert_numbers(351061);
    insert_numbers(172933); // !!
    insert_numbers(172935);
    
    return 0;
}

注意到同样插入1e5个数, 耗时差距极大

如果我们能估算出bucket最终增长到的数目, 并且哈希值的计算过程简单, 我们就能构造出大量的key, 它们的哈希值最终落在同一个bucket里.
这就是x = 172933做的事, 因为std::hash<u_int64_t>特化后直接用key值当其哈希值了. 这里的耗时集中在第一个赋值循环: 因为不是multimap, 每次插入需遍历既定的bucket; 但第二个遍历累加几乎不耗这是笔者没预想到的, 试图用perf去分析但没成功. 此外x=20753行的结果也好理解: 存在一段时间的中间状态, bucket总数为20753, 在此期间有相当多的元素落在同一个bucket内. 当元素数目进一步增加bucket扩容 (因为最终的bucket计数为172933), 暂时落在同一个bucket内的元素会重新分散到不同的buckets.

上述STL哈希表有什么问题? Golang的map怎么设计的?

  • 哈希值计算过于简单, 特别是我们成功hack的整型key. 当然这里有trade-off, 一般认为浮点数其实也不太适合作key. 但是所有的随机变量在实现中都写死了, 多少有点让人不安.
  • 过于依赖"素数"降低哈希值映射到bucket index的碰撞. 但实际在哈希值不够均匀的情况下, 这不管用.
  • 基于负载因子的re-hash策略过于死板. 它无力识别这样的负载因子尚可, 但实际病态的如下情况: 一个bucket承载了几乎所有的元素, 其余的bucket负载几乎为零, 且bucket总数大于元素总数. 从宏观的负载因子考虑, 它会依然认为局面良好.
  • "开链"当链条长度很长时, 效果太差. 插入最坏情况下得顺着单向迭代器遍历整个bucket, 局部性极差.

Golang map的设计, 笔者认为在上述 4 个方面都有一定程度的优化. 其基本结构如下:

// See runtime/map.go
// ------------------
// This file contains the implementation of Go's map type.
//
// A map is just a hash table. The data is arranged
// into an array of buckets. Each bucket contains up to
// 8 key/elem pairs. The low-order bits of the hash are
// used to select a bucket. Each bucket contains a few
// high-order bits of each hash to distinguish the entries
// within a single bucket.
// A header for a Go map.
type hmap struct {
    // Note: the format of the hmap is also encoded in cmd/compile/internal/reflectdata/reflect.go.
    // Make sure this stays in sync with the compiler's definition.
    count     int // # live cells == size of map.  Must be first (used by len() builtin)
    flags     uint8
    B         uint8  // log_2 of # of buckets (can hold up to loadFactor * 2^B items)
    noverflow uint16 // approximate number of overflow buckets; see incrnoverflow for details
    hash0     uint32 // hash seed

    buckets    unsafe.Pointer // array of 2^B Buckets. may be nil if count==0.
    oldbuckets unsafe.Pointer // previous bucket array of half the size, non-nil only when growing
    nevacuate  uintptr        // progress counter for evacuation (buckets less than this have been evacuated)

    extra *mapextra // optional fields
}

更随机化的哈希函数

一个容易注意到的特点是显式地维护了哈希表的种子hash0, 每个哈希表变量独立维护: 我们回想一下C++ STL在特化的时候直接把随机数写死了. 种子在建表的时候初始化, 当表变空时更新 (注意fastrand()函数).

// See runtime/map.go
// ------------------
// makemap implements Go map creation for make(map[k]v, hint).
// If the compiler has determined that the map or the first bucket
// can be created on the stack, h and/or bucket may be non-nil.
// If h != nil, the map can be created directly in h.
// If h.buckets != nil, bucket pointed to can be used as the first bucket.
func makemap(t *maptype, hint int, h *hmap) *hmap {
    mem, overflow := math.MulUintptr(uintptr(hint), t.Bucket.Size_)
    if overflow || mem > maxAlloc {
        hint = 0
    }

    // initialize Hmap
    if h == nil {
        h = new(hmap)
    }
    h.hash0 = fastrand()
}

func mapdelete(t *maptype, h *hmap, key unsafe.Pointer) {
    // ...
    h.count--
    // Reset the hash seed to make it more difficult for attackers to
    // repeatedly trigger hash collisions. See issue 25237.
    if h.count == 0 {
        h.hash0 = fastrand()
    }
    break search
    // ...
}

// mapclear deletes all keys from a map.
func mapclear(t *maptype, h *hmap) {
    // ...
    // Reset the hash seed to make it more difficult for attackers to
    // repeatedly trigger hash collisions. See issue 25237.
    h.hash0 = fastrand()
    // ...
}

对于浮点数key, 如float64, 结合hash0后使用如下哈希函数:

// See runtime/algo.go
// -------------------
func f64hash(p unsafe.Pointer, h uintptr) uintptr {
    f := *(*float64)(p)
    switch {
    case f == 0:
        return c1 * (c0 ^ h) // +0, -0
    case f != f:
        return c1 * (c0 ^ h ^ uintptr(fastrand())) // any kind of NaN
    default:
        return memhash(p, h, 8)
    }
}

// runtime variable to check if the processor we're running on
// actually supports the instructions used by the AES-based
// hash implementation.
var useAeshash bool

// in asm_*.s
func memhash(p unsafe.Pointer, h, s uintptr) uintptr
func memhash32(p unsafe.Pointer, h uintptr) uintptr
func memhash64(p unsafe.Pointer, h uintptr) uintptr
func strhash(p unsafe.Pointer, h uintptr) uintptr

memhash这个哈希函数的具体逻辑, 估计目前很难看懂了. 而且笔者目前不知道如何像C++的hash_function()仿函数一样, 指定哈希表对象与key值后, 获取对应的哈希值 (悲).

更友好的bucket映射方式与结构设计

下一步, 怎么根据哈希值找到bucket index? Golang的做法是mask哈希值取低位直接作为bucket index, 被mask掉的高位另有用途. 这里每个bucket的索引是连续存储的, 看成一个数组, hmap的成员buckets是指向该数组的指针, 结合mask出的offset直接找到目标bucket. 也因此完全破除了bucket数目对大素数的依赖. 在map中, bucket的数目一般为2的幂, 方便压缩编码也方便mask.

// See runtime/map.go
// ------------------
// mapaccess1 returns a pointer to h[key].  Never returns nil, instead
// it will return a reference to the zero object for the elem type if
// the key is not in the map.
// NOTE: The returned pointer may keep the whole map live, so don't
// hold onto it for very long.
func mapaccess1(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
    // ...
    hash := t.Hasher(key, uintptr(h.hash0))
    m := bucketMask(h.B)
    b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.BucketSize)))
    if c := h.oldbuckets; c != nil {
        if !h.sameSizeGrow() {
            // There used to be half as many buckets; mask down one more power of two.
            m >>= 1
        }
        oldb := (*bmap)(add(c, (hash&m)*uintptr(t.BucketSize)))
        if !evacuated(oldb) {
            b = oldb
        }
    }
    top := tophash(hash)
    // ...
}
// See runtime/map.go
// ------------------
// A bucket for a Go map.
type bmap struct {
    // tophash generally contains the top byte of the hash value
    // for each key in this bucket. If tophash[0] < minTopHash,
    // tophash[0] is a bucket evacuation state instead.
    tophash [bucketCnt]uint8
    // Followed by bucketCnt keys and then bucketCnt elems.
    // NOTE: packing all the keys together and then all the elems together makes the
    // code a bit more complicated than alternating key/elem/key/elem/... but it allows
    // us to eliminate padding which would be needed for, e.g., map[int64]int8.
    // Followed by an overflow pointer.
}

// A hash iteration structure.
// If you modify hiter, also change cmd/compile/internal/reflectdata/reflect.go
// and reflect/value.go to match the layout of this structure.
type hiter struct {
    key         unsafe.Pointer // Must be in first position.  Write nil to indicate iteration end (see cmd/compile/internal/walk/range.go).
    elem        unsafe.Pointer // Must be in second position (see cmd/compile/internal/walk/range.go).
    t           *maptype
    h           *hmap
    buckets     unsafe.Pointer // bucket ptr at hash_iter initialization time
    bptr        *bmap          // current bucket
    overflow    *[]*bmap       // keeps overflow buckets of hmap.buckets alive
    oldoverflow *[]*bmap       // keeps overflow buckets of hmap.oldbuckets alive
    startBucket uintptr        // bucket iteration started at
    offset      uint8          // intra-bucket offset to start from during iteration (should be big enough to hold bucketCnt-1)
    wrapped     bool           // already wrapped around from end of bucket array to beginning
    B           uint8
    i           uint8
    bucket      uintptr
    checkBucket uintptr
}

    // Maximum number of key/elem pairs a bucket can hold.
    bucketCntBits = abi.MapBucketCountBits
    bucketCnt     = abi.MapBucketCount

// See src/internal/abi/map.go
// ---------------------------
// Map constants common to several packages
// runtime/runtime-gdb.py:MapTypePrinter contains its own copy
const (
    MapBucketCountBits = 3 // log2 of number of elements in a bucket.
    MapBucketCount     = 1 << MapBucketCountBits
    MapMaxKeyBytes     = 128 // Must fit in a uint8.
    MapMaxElemBytes    = 128 // Must fit in a uint8.
)

哈希值的最高8位 (被mask掉的) 作为top, 有助于在bucket内快速判断原始的key是否出现过: 在STL中我们可得遍历整个链表. top值存储在capacity固定的数组bmap.tophash中, capacity写死为8限制了单个bucket的承载能力. 紧接着bmap.tophash数组的, 是指示key, value且capacity同样固定的指针数组. 如下当我们试图插入k/v, 如果当前对应的bmap没写满到其capacity, i指示了下一个应被写入的slot. dataOffset指示了key指针数组的起始地址, 从指针计算操作也能发现这三个数组都定长连续存储.

// See runtime/map.go
// ------------------
// Like mapaccess, but allocates a slot for the key if it is not present in the map.
func mapassign(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
// ...
bucketloop: 
    for i := uintptr(0); i < bucketCnt; i++ {
            if b.tophash[i] != top {
                if isEmpty(b.tophash[i]) && inserti == nil {
                    inserti = &b.tophash[i]
                    insertk = add(unsafe.Pointer(b), dataOffset+i*uintptr(t.KeySize))
                    elem = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize))
                }
                if b.tophash[i] == emptyRest {
                    break bucketloop
                }
                continue
            }
        }
// ...
}

感谢Draven佬的高质量图片

OK, 一个bmap作为bucket至多存8对K/V, 写满了还要写这个bucket怎么办?

// See runtime/map.go
// ------------------
// Like mapaccess, but allocates a slot for the key if it is not present in the map.
func mapassign(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
    // ...
    if inserti == nil {
        // The current bucket and all the overflow buckets connected to it are full, allocate a new one.
        newb := h.newoverflow(t, b)
        inserti = &newb.tophash[0]
        insertk = add(unsafe.Pointer(newb), dataOffset)
        elem = add(insertk, bucketCnt*uintptr(t.KeySize))
    }
    // ...
}

func (h *hmap) newoverflow(t *maptype, b *bmap) *bmap {
    var ovf *bmap
    // ...
    // b --> ovf
    b.setoverflow(t, ovf)
    return ovf
}

func (b *bmap) setoverflow(t *maptype, ovf *bmap) {
    *(**bmap)(add(unsafe.Pointer(b), uintptr(t.BucketSize)-goarch.PtrSize)) = ovf
}

bmap b创建新的bmap ovf, 称ovf为b的overflow bucket. b除了三个上述的数组外 (key和value指针数组是不是真的定长, 是不是都永远只存指针不存数据笔者暂时不确定), 最后还有一个约8 byte的域, 存放指向ovf的指针; 完整的bmap bucket size为t.BucketSize. 这样若干个bmap节点就串成了个链表, 其中只有开头的节点占据最高层的hmap结构的slot计数 (即hmap.buckets数组的条目), 后面的节点贡献overflow计数 (即hamp.noverflow标量).

感谢Draven佬的高质量图片

OK, 最后我们看看它会在什么时候触发re-hash?

// Like mapaccess, but allocates a slot for the key if it is not present in the map.
func mapassign(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
    // Did not find mapping for key. Allocate new cell & add entry.

    // If we hit the max load factor or we have too many overflow buckets,
    // and we're not already in the middle of growing, start growing.
    if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
        hashGrow(t, h)
        goto again // Growing the table invalidates everything, so try again
    }
}

// overLoadFactor reports whether count items placed in 1<<B buckets is over loadFactor.
func overLoadFactor(count int, B uint8) bool {
    return count > bucketCnt && uintptr(count) > loadFactorNum*(bucketShift(B)/loadFactorDen)
}

// Maximum average load of a bucket that triggers growth is bucketCnt*13/16 (about 80% full)
// Because of minimum alignment rules, bucketCnt is known to be at least 8.
// Represent as loadFactorNum/loadFactorDen, to allow integer math.
loadFactorDen = 2
loadFactorNum = (bucketCnt * 13 / 16) * loadFactorDen

// tooManyOverflowBuckets reports whether noverflow buckets is too many for a map with 1<<B buckets.
// Note that most of these overflow buckets must be in sparse use;
// if use was dense, then we'd have already triggered regular map growth.
func tooManyOverflowBuckets(noverflow uint16, B uint8) bool {
    // If the threshold is too low, we do extraneous work.
    // If the threshold is too high, maps that grow and shrink can hold on to lots of unused memory.
    // "too many" means (approximately) as many overflow buckets as regular buckets.
    // See incrnoverflow for more details.
    if B > 15 {
        B = 15
    }
    // The compiler doesn't see here that B < 16; mask B to generate shorter shift code.
    return noverflow >= uint16(1)<<(B&15)
}

两个条件满足其一:

  • hit the max load factor.
  • we have too many overflow buckets.

对于第一点, 展开完了就是count>\frac{bucketCnt\times13}{16}\times2\times\frac{cnt_{bucket}}{loadFactorDen}. 其中bucketCnt=8, loadFactorDen=2, 也即负载因子>\frac{13}{2}时触发扩容 (但\frac{bucketCnt\times13}{16}这个计算笔者看过了应该是整型数啊, 为什么很多文章说是6.5?). 这个"80% full"是不是有点thumb-up-rule就不得而知了.
第二点价值更大: 当overflow bucket多于链表的条数, 就认为"have too many overflow buckets". 这就规避了我们之前hack STL哈希表所利用的情况.

总结一下Golang map优秀的一些设计理念:

  • 更随机化的种子与哈希函数, 同时隐藏部分接口. 导致更难预测key对应的哈希值.
  • bucket"连续+离散"的设计: 单个hmap内部是连续的, 迅速遍历tophash有效应对哈希不命中的普遍情况. hmap和它的overflow bucket单链表离散串连赋予扩展性. "连续"的部分, 其长度开始就可以以固定的模式设计.
  • 更好的re-hash机制: 既考虑hmap平均装载量 (即"80%"); 又考虑链表的长度之和 (overflow机制).

后续

  • 其他语言或项目的哈希表? 如Redisset, Linux kernel内网络子系统用的哈希表 (如半连接队列)?
  • 怎么hack STL中, key是其他形式特化的哈希表? 比如key是double, std::string? 能不能直接hack murmur函数?
  • 哈希表的迭代器是怎么跑的. 既然是无序的, 那么范围遍历时, 具体是怎么做的?

参考

1: hack unordered_map的简单讨论. https://codeforces.com/blog/entry/126463
2: 理解 Golang 哈希表 Map的原理 (致谢高质量配图). https://draveness.me/golang/docs/part2-foundation/ch03-datastructure/golang-hashmap/
3: libstdc++中哈希函数的实现. https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B-v3/libsupc%2B%2B/hash_bytes.cc
4: 关于STL hashtable不错的帖子. https://www.cnblogs.com/Commence/p/9057364.html

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容