Swisstable Hash に使われているビット演算の魔術

Googleが開発したSwisstableと呼ばれるハッシュテーブル実装がAbseilとして公開されて、Rustの標準のHashMap実装にもその移植であるhashbrownが採用されました。

Swisstable の面白いところは、8または16要素をグループ化して、グループ内の各要素のハッシュ値のうち7bitをそれぞれ1byteに格納した8または16バイトの配列を作り、その配列に対して一気に並列でマッチングを行うことです。

この並列マッチングにはSSE2もしくはビット演算が使われます。この記事ではこの並列マッチング部分について解説します。

SSE2を使う場合

SSE2を使う場合は、グループのサイズは16になります。ハッシュ値を格納する配列のことを control と呼ぶことにすると、 control は char control[16] になります。control の各バイトの状態は次のようになります。

  • 0-127: 要素が入っているとき、そのハッシュ値のうち7bitを示します。
  • 0x80: 空 (EMPTY)
  • 0xFE: 削除済み (DELETED)

まず、要素を検索するときは、その要素のハッシュ値のうち7bit (上位か下位を切り取ることが多い) を char h2 = hash & 0x7f; のように切り出して、control の中で h2 に一致する要素を走査します。C言語擬似コードを書くと次のようになります。

for (uint64_t found = match(control, h2); found; next_pos(&found)) {
    int pos = first_pos(found);
    assert(control[pos] == h2);
    // 要素の配列の pos 番目と検索しているキーを比較する
}

match, first_pos, next_pos の実装を見ていきます。1行ずつコメントを入れたので読んでみてください。

uint64_t match(const char *control, char c) {
    // controlを __m128i に変換する. 16byte alignment に注意
    __m128i v = *(__m128i*)control;
    // c を16回繰り返した __m128i を作る
    __m128i w = _mm_set1_epi8(c);
    // vとwをバイトごとに比較し、一致したバイトを0xffに、しなかったバイトを0x00にする
    __m128i x = _mm_cmpeq_epi8(v, w);
    // 各バイトの符号ビット(最上位bit)を集めてint型にする
    return _mm_movemask_epi8(x);
}

int first_pos(uint64_t x) {
    // 最下位bitから、0のbitが何個あるか(=一番下の1のbitの位置)を返す。
    // 簡単のために gcc のビルトイン命令を使用してる。
    // MSVCなら _BitScanForward64() を使う
    return __builtin_ctzll(x);
}

void next_found(uint64_t *x) {
    // 一番下の1のbitを0に反転させる。
    *x &= (*x - 1);
}

実際に試してみましょう。

#include <stdio.h>
#include <stdint.h>
#include <emmintrin.h>
#include <mmintrin.h>

uint64_t match(const char *control, char c) {
    // controlを __m128i に変換する. 16byte alignment に注意
    __m128i v = *(__m128i*)control;
    // c を16回繰り返した __m128i を作る
    __m128i w = _mm_set1_epi8(c);
    // vとwをバイトごとに比較し、一致したバイトを0xffに、しなかったバイトを0x00にする
    __m128i x = _mm_cmpeq_epi8(v, w);
    // 各バイトの符号ビット(最上位bit)を集めてint型にする
    return _mm_movemask_epi8(x);
}

int first_pos(uint64_t x) {
    // 最下位bitから、0のbitが何個あるか(=一番下の1のbitの位置)を返す。
    // 簡単のために gcc のビルトイン命令を使用してる。
    // MSVCなら _BitScanForward64() を使う
    return __builtin_ctzll(x);
}

void next_pos(uint64_t *x) {
    // 一番下の1のbitを0に反転させる。
    *x &= (*x - 1);
}

int main() {
    char control[16] = {
        0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
        0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17};

    for (uint64_t found = match(control, 0x13); found; next_pos(&found)) {
        printf("found = %lx\n", found);
        int pos = first_pos(found);
        printf("pos = %d\n", pos);
    }

    return 0;
}

出力:

found = 808
pos = 3
found = 800
pos = 11

配列の3番目と13番目に 0x13 が入っていることを見つけられました。

SSE2を使わない場合

今時SSE2が使えないCPU?と思われるかもしれませんが、ARMも広く使われています(NEONへの最適化は今のところうまくいっていないようです)し、今後はRISC-Vも重要になってくるでしょう。 なので、古典的なビット演算の魔術が現代でも重宝されるのです。面白いですね。

SSE2を使わない場合は、 __m128i の代わりに int64_t を使い、グループのサイズも8に減ります。 for 文の部分は同じなので、 match, first_pos, next_pos の実装がどうなるかをみていきましょう。

#include <stdio.h>
#include <stdint.h>

static const uint64_t LSB = 0x0101010101010101ul;
static const uint64_t MSB = 0x8080808080808080ul;

uint64_t match(const char *control, char c) {
    // controlを uint64_t に変換する. alignment に注意
    uint64_t v = *(uint64_t*)control;
    // c を8回繰り返した int64_t を作る
    uint64_t w = LSB * c;
    // vとwをbitごとに比較する。
    // バイト単位でみたとき、一致したら0に、一致しなかったら0以外になる
    uint64_t x = v ^ w;
    // 各バイトが0なら0x80に、それ以外なら0x00にする。
    return (x-LSB) & (~x & MSB);
}

int first_pos(uint64_t x) {
    // 最下位bitから、0のbitが何個あるか(=一番下の1のbitの位置)を返す。
    // 簡単のために gcc のビルトイン命令を使用してる。
    // MSVCなら _BitScanForward64() を使う
    //
    // x = 0x80 の場合、 __builtin_ctzll(x) は 7 になり、7>>3 = 1 になる.
    // x = 0x8000 の場合、 __builtin_ctzll(x) は 15 になり、15>>3 = 1 になる.
    return __builtin_ctzll(x) >> 3;
}

void next_pos(uint64_t *x) {
    // 一番下の1のbitを0に反転させる。
    *x &= (*x - 1);
}


int main() {
    char control[16] = {0x10, 0x11, 0x12, 0x13, 0x14, 0x13, 0x12, 0x11};

    for (uint64_t found = match(control, 0x13); found; next_pos(&found)) {
        printf("found = %lx\n", found);
        int pos = first_pos(found);
        printf("pos = %d\n", pos);
    }

    return 0;
}

match() の最後の、 (x-LSB) & (~x & MSB) が難しいのでもう少し噛み砕きます。 各バイトに注目したとき、 LSB は 0x01 になっているので、 x が 0x00 なら 0xff になり、最上位bitが立ちます。しかし x が 0x81 以上の時も -1 しても最上位bitが立っているので、それだけでは判定できません。 右辺の (~x & MSB) で、MSBは0x80なので、xの最上位bit (0x80) が1のときは 0x00 になり、最上位ビットが0のときは0x80になります。 この左辺と右辺を & することで、各バイトが0のときだけ 0x80 になるのです。

さて、このプログラムを実行してみましょう。

found = 80800080000000
pos = 3
found = 80800000000000
pos = 5
found = 80000000000000
pos = 6

0x13 は 3, 5 番目だけなのに、 6 番目も検出されてしまっています!!!

これは、matchがSIMDでないために起こっています。pos=5,6がmatch()でどう処理されるかを見ていきましょう。

  1. control[6], control[5] は 0x1213 になっています(アドレスが低い方が低位バイトになってるので左右逆転します)
  2. 0x1213 ^ 0x1313 = 0x0100 になります。(一致している0x13の部分がちゃんと00になっている)
  3. 0x0100 - 0x0101 = 0xffff になります。 (0x01 の部分のbitが繰り下がりに消費されてしまっています。)
  4. ~0x0100 & 0x8080 = 0x8080
  5. 0x8080 & 0xffff = 0x8080

このように、bit演算は問題ないのですが、引き算によってバイト単位の処理が壊れてしまっていることがわかります。

この誤検出(false positive)が発生するのは、比較対象と最下位ビット (0x01) だけの違いしかなかくて、かつ隣の下位バイトでfalse positiveまたはtrue positiveが発生して繰り下がりが発生する場合に起こります。 (0x0200 - 0x0101 = 0x00ff; 0x00ff & 0x8080 = 0x0080)

ハッシュテーブルの探索でも1/128の確率でこの誤検出が発生しますが、これは大した問題ではありません。

通常のハッシュテーブルでは、大きさ8あたりハッシュ値の3bitを使いますが、Swisstable (SSEなし)の場合8要素のグループ内の検索に7bitを使っているので、そもそも衝突の頻度は4bit分少ないのです。 false positiveにより、true positiveの隣のバイト限定で1bitの違いが無視されますが、それでも6bit分は有効に照合されているので、古典的なハッシュテーブルよりはずっと「ハッシュ値の一部しか一致しない要素に対する比較」は減っているのです。

なお、EMPTY (0x80) と DELETED (0xFE) は両方とも最上位bitが立っているので、 false positive を起こしません。なので、EMPTYやDELETEDに該当する要素を見に行って、未初期化メモリやdangling pointerを参照することは避けられます。

Bit Twiddling Hacks

Swisstable実装で使われていたビット操作は、参考陸としてこのサイトが紹介されていました。

Bit Twiddling Hacks

x86以外でも速く動作するプログラムを書きたい場合は、どんなハックがあるのか目を通しておくのがいいかもしれません。

このブログに乗せているコードは引用を除き CC0 1.0 で提供します。