base64编码,即将一段二进制中,每6bits用一个字符表示,以将二进制转换成易读易传输的字符串。6bits正好对应64个字符,故为base64。
TB64avx2库,或者叫Turbo64avx2库,在github很容易找到,是借助于intel avx2指令集,实现base64的编码和解码,速度非常快,查表法通常已经是优化极限了,而使用SIMD指令的实现比查表法还要快8倍。
当前该库不支持url safe的base64编码和解码,故而『被迫』剖析其源码,以实现url safe的支持。
小注:何为url safe编码?标准base64使用的字符集为:
ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/
而为了让base64的编码能用在url里,字符集换为:
ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_
就是末尾两个字符换掉了,以避免与url里空格和分隔发生冲突。
源码考虑了查表法和多种SIMD指令的实现办法,包括ssse和avx和avx2和avx512,不过好像原作者的实现并不完善,代码有好些问题,至少代码里没有avx512的实现,不过,查询表法和avx2法是完善的,这里也只考虑这两种办法。通常intel的cpu也主要支持ssse3和avx2,而像avx512,我查了手头的好几台服务器,也都不支持。
// Memory efficient (small lookup tables) scalar but (slower) version
size_t tb64senc( const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64sdec( const unsigned char *in, size_t inlen, unsigned char *out);
// Fast scalar
size_t tb64xenc( const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64xdec( const unsigned char *in, size_t inlen, unsigned char *out);
// ssse3
size_t tb64sseenc( const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64ssedec( const unsigned char *in, size_t inlen, unsigned char *out);
// avx
size_t tb64avxenc( const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64avxdec( const unsigned char *in, size_t inlen, unsigned char *out);
// avx2
size_t tb64avx2enc(const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64avx2dec(const unsigned char *in, size_t inlen, unsigned char *out);
// avx512
size_t tb64avx512enc(const unsigned char *in, size_t inlen, unsigned char *out);
size_t tb64avx512dec(const unsigned char *in, size_t inlen, unsigned char *out);
代码通过编译宏,来包含不同的指令头文件,从而识别当前cpu支持哪些simd指令,然后选择支持范围内的最快版本的代码进行执行,最快的也就是tb64avx2enc和tb64avx2dec。
#if defined(__AVX__)
#include <immintrin.h>
#define FUNPREF tb64avx
#elif defined(__SSE4_1__)
#include <smmintrin.h>
#define FUNPREF tb64sse
#elif defined(__SSSE3__)
#ifdef __powerpc64__
#define __SSE__ 1
#define __SSE2__ 1
#define __SSE3__ 1
#define NO_WARN_X86_INTRINSICS 1
#endif
#define FUNPREF tb64sse
#include <tmmintrin.h>
#elif defined(__SSE2__)
#include <emmintrin.h>
#elif defined(__ARM_NEON)
#include <arm_neon.h>
#endif
先来熟悉一下查询法的base64编码算法,算法在turbob64c.c中,函数名叫tb64senc,这里的s可能是simple/slow/small吧,代码十几行,非常简单,此处不多说。
查表法解码代码为turbob64d.c中的tb64sdec()函数,代码也不多,反向查表,表就要大一些,范围为0到256。特别注意,这里有两个62和63,表示同时支持+/和-_的解码,而不需要考虑是否url safe,接下来的avx2相关代码也是要达到这一目的:
#define _ 0xff // invald entry
const unsigned char lut[] = {
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _,62, _,62, _,63,
52,53,54,55,56,57,58,59,60,61, _, _, _, _, _, _,
_, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,
15,16,17,18,19,20,21,22,23,24,25, _, _, _, _,63,
_,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,
41,42,43,44,45,46,47,48,49,50,51, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _,
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _
};
#undef _
特别注意两点:
1. 在tb64avx2enc和tb64avx2dec的实现里,不能对齐32个字节的部分还是要回归到查询表法来处理。32个字节,也就是256bit,正好是avx2指令操作的单元类型__m256i。
2. base64的解码有是否check的问题,作者用B64CHECK宏来控制,可以编码时指定是否进行check,但我发现就算指定了check,也只是检查每256bit首字节是否为base64编码的合法字符,并不能完全检查,当然实测发现完全check的代价很高,很影响性能。个人觉得没有必要check,因为实际生产环境中,base64本身不具有自我纠错能力,应该在解码完成之后,手动检查其结构是否符合合法结构,比如是否为jpg图片,是否为json等。只是简单的字符集合法性检验并不能保证内容没被篡改。
avx2直接操作256bits,这也就是它快的原因,罗列几个相关指令:
使用32个uint8_t整型来直接初始化一个256bits的寄存器,举例:
__m256i cpv = _mm256_set_epi8( -1, -1, -1, -1, 12, 13, 14, 8, 9, 10, 4, 5, 6, 0, 1, 2,
-1, -1, -1, -1, 12, 13, 14, 8, 9, 10, 4, 5, 6, 0, 1, 2);
从内存中加载数据来填充一个寄存器,举例:
__m256i iv0 = _mm256_loadu_si256((__m256i *)ip);
相反的将寄存器里的数据存储到内存中,寄存器要比内存快得多,所以应该尽量减少从内存中加载数据和存储数据到内存,举例:
_mm256_storeu_si256((__m256i*) op, v0);
每32bit作为一块,将每块向右进行位移,左边用0补齐,当然如果移动位数大于31位,就全为0了。intel的说明书有伪码:
__m256i _mm256_srli_epi32 (__m256i a, int imm8)
#include <immintrin.h>
Instruction: vpsrld ymm, ymm, imm8
CPUID Flags: AVX2
operation:
FOR j := 0 to 7
i := j*32
IF imm8[7:0] > 31
dst[i+31:i] := 0
ELSE
dst[i+31:i] := ZeroExtend32(a[i+31:i] >> imm8[7:0])
FI
ENDFOR
dst[MAX:256] := 0
计算两个256bits的平均值,也就是一个指令下去,一次性可以计算32个int8_t的平均值,这就是simd的威力,举例:
const __m256i delta_hash = _mm256_avg_epu8(tmp, shifted);
这个我感觉是所有指令中的重中之重,是算法实现的基石,因为它提供了并行查表的能力。它如果用C语言来写就是:
//__m256i _mm256_shuffle_epi8(__m256i a, __m256i b);
for (i = 0; i < 16; i++){
if (b[i] & 0x80){
r[i] = 0;
}
else
{
r[i] = a[b[i] & 0x0F];
}
}
也就是以a为查询表,按b的每字节中的低4位来查表,所以只能索引到表a中的0到15范围,一次指令可以查询16个int8_t。特别注意:当最高位为1时,对应查表结果总为0.
这个没啥好说的,就是两个寄存器的值想加,返回一个新的寄存器,也就是一次可以完成32个小整数的相加。如果相加有进位,丢弃进位。
代码不多,但是很精妙:
const unsigned char *ip = in;
unsigned char *op = out;
size_t outlen = TB64ENCLEN(inlen);
...
const __m256i shuf = _mm256_set_epi8(10,11, 9,10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1,
10,11, 9,10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1);
for(; op <= (out+outlen)-(64+OVD); op += 64, ip += (64/4)*3) {
__m256i v0 = _mm256_castsi128_si256( _mm_loadu_si128((__m128i *) ip));
v0 = _mm256_inserti128_si256(v0,_mm_loadu_si128((__m128i *)(ip+12)),1);
__m256i v1 = _mm256_castsi128_si256( _mm_loadu_si128((__m128i *)(ip+24)));
v1 = _mm256_inserti128_si256(v1,_mm_loadu_si128((__m128i *)(ip+36)),1);
v0 = _mm256_shuffle_epi8(v0, shuf); v0 = mm256_unpack6to8(v0); v0 = mm256_map6to8(v0);
v1 = _mm256_shuffle_epi8(v1, shuf); v1 = mm256_unpack6to8(v1); v1 = mm256_map6to8(v1);
_mm256_storeu_si256((__m256i*) op, v0);
_mm256_storeu_si256((__m256i*)(op+32), v1);
}
...
_mm256_set_epi8()的32个小整数,前16个和后16个是一样的,这是因为作为table查询,_mm256_shuffle_epi8()只能查前16个,所以后16个无意义,只是作为填充而已。
base64编码是正好12个8bits正好对应16个6bits,shuf用来完成这种对应关系,会有重复bits,但不用担心,mm256_unpack6to8()完成消除多余bits和移位的工作:
static ALWAYS_INLINE __m256i mm256_unpack6to8(__m256i v) { /* https://arxiv.org/abs/1704.00605 p.12*/
__m256i va = _mm256_mulhi_epu16(_mm256_and_si256(v, _mm256_set1_epi32(0x0fc0fc00)), _mm256_set1_epi32(0x04000040));
__m256i vb = _mm256_mullo_epi16(_mm256_and_si256(v, _mm256_set1_epi32(0x003f03f0)), _mm256_set1_epi32(0x01000010));
return _mm256_or_si256(va, vb);
}
这里通过位与来消除多余bits,通过乘法来实现移位,不多解释。
mm256_map6to8()是编码的重点,它实现将0转成A,1转成B,2转成C,如此类推:
static ALWAYS_INLINE __m256i mm256_map6to8(const __m256i v) { /*map 6-bits bin to 8-bits ascii (https://arxiv.org/abs/1704.00605) */
__m256i vidx = _mm256_subs_epu8(v, _mm256_set1_epi8(51));
vidx = _mm256_sub_epi8(vidx, _mm256_cmpgt_epi8(v, _mm256_set1_epi8(25)));
const __m256i offsets =
_mm256_set_epi8(0, 0, -16, -19, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, 71, 65,
0, 0, -16, -19, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, 71, 65);
return _mm256_add_epi8(v, _mm256_shuffle_epi8(offsets, vidx));
}
如果有一个64x8bits的大表让我们查,那就万事大吉了,然后最多只支持查16x8bits的表。不过幸运的是base64的码表是有一定规律的,0到25和A到Z之间总是相差65,26到51和a到z之间总是相差71,52到61和'0'到'9'之间总是相差-4,这三个区段都是连续的,不连续的就是62和'+'相差-19,63和'/'相差-16。我们要支持的url safe的话,62和'-'相差-17,63和'_'相差32.
所以意思就比较明显了,对于输入向量v,我们将0到25都转成0xf,将26到51都转成0xe,这样就正好索引到65和71,再加回到向量v,就实现了0向A,1向B,2向C的转换。因为容量足够,这里作者没有考虑'0'到'9'的映射,所以offsets里有10个-4,接着就是'+'和'/'。所以需要支持safe url编码,只需要将上面的-19改成-17,-16改成32,即可。
说完编码再来说解码,解码要有挑战得多。如下是核心代码:
const unsigned char *ip;
unsigned char *op;
...
const __m256i delta_asso = _mm256_setr_epi8(0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0f,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0f);
const __m256i delta_values = _mm256_setr_epi8(0x00, 0x00, 0x00, 0x13, 0x04, 0xbf, 0xbf, 0xb9, 0xb9, 0x00, 0x10, 0xc3, 0xbf, 0xbf, 0xb9, 0xb9,
0x00, 0x00, 0x00, 0x13, 0x04, 0xbf, 0xbf, 0xb9, 0xb9, 0x00, 0x10, 0xc3, 0xbf, 0xbf, 0xb9, 0xb9);
__m256i cpv = _mm256_set_epi8( -1, -1, -1, -1, 12, 13, 14, 8, 9, 10, 4, 5, 6, 0, 1, 2,
-1, -1, -1, -1, 12, 13, 14, 8, 9, 10, 4, 5, 6, 0, 1, 2);
for(ip = in, op = out; ip < (in+inlen)-(64+OVD); ip += 64, op += (64/4)*3) {
__m256i iv0 = _mm256_loadu_si256((__m256i *)ip);
__m256i iv1 = _mm256_loadu_si256((__m256i *)(ip+32));
__m256i ov0,shifted0; MM256_MAP8TO6(iv0, shifted0, delta_asso, delta_values, ov0); MM256_PACK8TO6(ov0, cpv);
__m256i ov1,shifted1; MM256_MAP8TO6(iv1, shifted1, delta_asso, delta_values, ov1); MM256_PACK8TO6(ov1, cpv);
_mm_storeu_si128((__m128i*) op, _mm256_castsi256_si128(ov0));
_mm_storeu_si128((__m128i*)(op + 12), _mm256_extracti128_si256(ov0, 1));
_mm_storeu_si128((__m128i*)(op + 24), _mm256_castsi256_si128(ov1));
_mm_storeu_si128((__m128i*)(op + 36), _mm256_extracti128_si256(ov1, 1));
}
...
这里一开始就创建了三个向量,让人很费解,但是大概流程还是很容易猜到,MM256_MAP8TO6()就是将'A'转成0,'B'转成1等等,MM256_PACK8TO6()是想办法位移,去掉每个字节前空白的两个bits,以得到最终解码结果。这其中的MM256_MAP8TO6()最为复杂,两个常向量delta_asso和delta_values也是给它用的。然后它的代码却非常简洁,如下:
#define MM256_MAP8TO6(iv, shifted, delta_asso, delta_values, ov) { /*map 8-bits ascii to 6-bits bin*/\
shifted = _mm256_srli_epi32(iv, 3);\
const __m256i delta_hash = _mm256_avg_epu8(_mm256_shuffle_epi8(delta_asso, iv), shifted);\
ov = _mm256_add_epi8(_mm256_shuffle_epi8(delta_values, delta_hash), iv);\
}
将上面的代码翻译成python代码大概是这样的:
delta_asso = (0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0f,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0f)
delta_values = (0x00, 0x00, 0x00, 0x13, 0x04, 0xbf, 0xbf, 0xb9, 0xb9, 0x00, 0x10, 0xc3, 0xbf, 0xbf, 0xb9, 0xb9,
0x00, 0x00, 0x00, 0x13, 0x04, 0xbf, 0xbf, 0xb9, 0xb9, 0x00, 0x10, 0xc3, 0xbf, 0xbf, 0xb9, 0xb9)
s = b'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/-_-_'
def srli32(a,b,c,d):
t = (a << 0) + (b << 8) + (c << 16) + (d << 24)
t >>= 3
return (t >> 0) & 0xff, (t >> 8) & 0xff, (t >> 16) & 0xff, (t >> 24) & 0xff
tmp = [srli32(s[i], s[i+1], s[i+2], s[i+3]) for i in range(0, len(s), 4)]
shifts = []
[shifts.extend(i) for i in tmp]
print('shifts:', shifts)
#print([i & 0xf for i in s])
asso_res = [delta_asso[i & 0xf] for i in s]
print('asso_res:', asso_res)
hash = [(asso_res[i] + shifts[i] + 1) >> 1 for i in range(len(s))]
print('hash:', hash)
hash_vals = [delta_values[i & 0xf] for i in hash]
print('hash_vals:', hash_vals)
res = [hash_vals[i] + s[i] for i in range(len(s))]
print('res:', res)
raw = [(hash_vals[i] + s[i]) & 0xff for i in range(len(s))]
print('raw:', raw)
先不看代码,说说我们要达到的目的,我们要将'A'映射成0,'B'映射成1... 但显然跟编码一样,我们需要利用A-Z和a-z的连续性规律,这样才能在16个8bits条件实现转换。
再从最后一句ov = _mm256_add_epi8(_mm256_shuffle_epi8(delta_values, delta_hash), iv); 和 delta_values 的特点出发,我们最终还是要通过查delta_values表,来得到iv的补数,求得最终结果。0xb9 = 185,而256 - 185正好是71,0xbf = 191,而256 - 191正好是65,正是'a'-'z'区段和'A'-'Z'区段所对应的差值,所以目前感觉是对的。
那么接下来问题来了,代码中的先移位_mm256_srli_epi32,再查表和移位的结果求均值_mm256_avg_epu8是干用的?答案是hash,这里的delta_asso应该是hash时的修正盐值。
现在有点头大了,如果我们希望加入'-'和'_'的支持,显然当前的hash算法就不成立了,这就不是简单修改一个值那么简单了,我们需要建立新的hash关系。
接着,这么来分析这个问题:我们将数据分类,一共7类,A-Z为类1,a-z为类2,'0'-'9'为类3,'+'为类4,'/'为类5,'-'为类6,'_'为类7。hash的bucket容量为16,我们要将同一类的元素hash到同一位置或者不同位置,如果是不同位置,那么delta_values里的取值就必须一样;而不同类的元素之间不能hash冲突,如果冲突,hash就失败,就调整盐值向量delta_asso的值,继续尝试。而我们首先就是需要找到一个满足条件的盐值向量来!
再注意一点,当我们右移三位时,无论当前是哪个字符,也无论它在四字节块中的哪个位置,右移之后的结果只有八种情形,也就是填充过来的3bits,有6种组合:[t | 0x0, t | 0x20, t | 0x40, t | 0x60, t | 0x80, t | 0xa0, t | 0xc0, t | 0xe0]。所以,每个字符我们在验证hash时,需要校验八种。
搜索时的检验代码为:
def validate(n, salt, tbl, grp):
t = n >> 3
for i in (t | 0x0, t | 0x20, t | 0x40, t | 0x60, t | 0x80, t | 0xa0, t | 0xc0, t | 0xe0):
idx = (i + salt[n & 0xf] + 1) >> 1
if idx & 0x80:
return False
idx &= 0xf
if tbl[idx] > 0 and tbl[idx] != grp:
return False
else:
tbl[idx] = grp
return True
def candidate(salt):
tbl = [0] * 16
for i in b'+':
if not validate(i, salt, tbl, 4):
print('fail')
return False
print(tbl)
for i in b'/':
if not validate(i, salt, tbl, 5):
print('fail')
return False
print(tbl)
for i in b'-':
if not validate(i, salt, tbl, 6):
print('fail')
return False
print(tbl)
for i in b'_':
if not validate(i, salt, tbl, 7):
print('fail')
return False
print(tbl)
for i in b'ABCDEFGHIJKLMNOPQRSTUVWXYZ':
if not validate(i, salt, tbl, 1):
print('fail')
return False
print(tbl)
for i in b'abcdefghijklmnopqrstuvwxyz':
if not validate(i, salt, tbl, 2):
print('fail')
return False
print(tbl)
for i in b'0123456789':
if not validate(i, salt, tbl, 3):
print('fail')
return False
print(tbl)
return True
salt是16长的list,是需要检验的盐值向量,tbl用来标记hash是否冲突,如果冲突,立马校验失败。特别注意:+/-_四个字符最容易引起冲突,所以这里是优先进行hash计算。
okay,接下来就是暴力搜索了。
不过可以稍微剪枝,有几点可以明确:
1. 将7类hash到16个桶里,很大概率会有很多组合适的结果,所以不必从[0x00] * 16,一直尝试到[0xff] * 16,这鬼知道要算多久。而且考虑均值计算的特点,查询也只取低4位,最差情况下我们也只需要尝试从[0x00] * 16 到 [0x20] * 16。
2. 基于如下原因,salt向量的第11、第13、第15,三个取值必不相等,否则就会冲突。
#In [1]: [n & 0xf for n in b'+/-_']
#Out[1]: [11, 15, 13, 15]
#In [1]: [n >> 3 for n in b'+/-_']
#Out[1]: [5, 5, 5, 11]
基于以上原因,我们基于原来的delta_asso,优先从第11、13、15三个元素的调整开始搜索:
def run():
salt = [0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0f]
for i in range(0x0, 0x20 + 1):
for j in range(0x0, 0x20 + 1):
for k in range(0x0, 0x20 + 1):
print('-------------')
salt[11] = i
salt[13] = j
salt[15] = k
if candidate(salt):
print('success: ', salt)
return
将python搜索代码改写成C语言,这样可以大大提高效率。实际的计算量是非常巨大的,如果希望找到一个同时支持url safe base64和standard base64的盐值矩阵,就需要遍历0x20 ^ 16 种组合,7类元素hash到16个桶,可能存在,也可能根本就不存在。实际执行了一个多小时,在测试了0x10 ^ 8 种组合,即将后8个字节的全排列进行了尝试没有找到后,放弃该最优方案,改为url safe base64一组,standard base64另一组。
注意:实际测试时发现,+/-_四个字符,任意去掉任一个,都能非常容易找到一个合适的盐值向量。
很快就得到搜索结果:
url safe[no support '/']: 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 13, 0, 0, 0, 12
standard[no support '_']: 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 14, 0, 12, 0, 0
有了盐值向量,一切就解决了,接着就是凑数了,让向量相加之后得到[0,1,2,3,4....,62,63,62,63,62,63],于是乎代码修改为:
const __m256i delta_asso = url_safe ?
_mm256_setr_epi8(1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 13, 0, 0, 0, 12,
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 13, 0, 0, 0, 12) : /* no support '/' */
_mm256_setr_epi8(1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 14, 0, 12, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 14, 0, 12, 0, 0); /* no support '_' */
const __m256i delta_values = url_safe ?
_mm256_setr_epi8(0, 0, 0, 17, 4, -65, -65, -71, -71, 19, 0, -65, -32, -71, 0, 0,
0, 0, 0, 17, 4, -65, -65, -71, -71, 19, 0, -65, -32, -71, 0, 0) :
_mm256_setr_epi8(0, 0, 0, 16, 4, -65, -65, -71, -71, 17, 19, -65, -65, -71, -71, 0,
0, 0, 0, 16, 4, -65, -65, -71, -71, 17, 19, -65, -65, -71, -71, 0);
搞定,以上就是对turbob64avx2的剖析!