这两天研究手搓GEMM,算swizzle给我算烦了,结果发现这个任务比我想象中要简单得多,但是没找到线程的轮子,实在受不了写了个C++库自动计算swizzle布局:melonedo/algebraic-layouts

背景

共享内存的 bank 划分

在英伟达的 GPU 中,为了支持灵活的共享内存访问,将共享内存划分为了 32 个 bank。每次发起共享内存事务(transation)时,可以从这 32 个 bank 中分别读取一个 32 位数据。以 32 位的字为单位索引,则 bank 以地址的低 5 位进行划分,与高位没有关系。每次发起共享内存请求时,可以任意访问这 32 个 bank 中每个 bank 的内容。32 对应 CUDA 中一个 warp 的线程数量,也是一次发起内存请求的最大数量。如果一次请求的 32 个地址没有均匀地分布在 32 个 bank 中,根据抽屉原理,则有一些 bank 对应了多组数据,称为产生了 bank conflict,需要多次发起共享内存事务才能读取到需要的信息,造成了带宽的浪费。更为重要的是,共享内存的吞吐并不高,一个 SM 中一个周期只能读取 32 个 bank 的内容各 4 字节(这个数字在很老的架构上已经固定下来),而进行 FMA 等计算时 SM 的吞吐在新一点的架构上可以达到每周期 128 个操作。

在高性能计算相关的应用中需要注意,如果软件中共享内存访问不是以 32 位为单位,而是以 64 位或是 128 位的向量化形式发起,执行时仍然是会拆分成若干个 32 位的请求,例如 32 个线程分别发起 128 位的请求,会按顺序,每 8 个线程为一组,进行 4 次内存访问,即 0-7、8-15、16-23、24-31号线程分别发起一次请求。发起的这 4 次请求各自独立,不会因为有 bank 冲突而智能地进行组合。在本文中,类似的向量化访问时,分析的是一次请求的内容,也就是对应这 4 次内存访问中的其中 1 次。如果按照这 4 个请求的整体进行分析,则访问同样的布局,会因为分配的线程序号不同,而发起不同的共享内存请求,硬件的访问模式也大不相同。

另外,一个 warp 中的 32 个线程也可以发起重复的地址的请求,此时会在硬件上自动合并重复的请求。

Swizzle 布局

上述 32 个 bank 的限制看似灵活,但是在软件实现上并不简单。32 个 bank 的划分方式,显然适合于 32 个线程访问数组中连续的 32 个元素的情况。然而,CUDA 的程序中有很多地方设计非连续的访问。例如,进行矩阵操作时,32 个线程一次只能读取一个大矩阵中的一小块,可能读取连续的一行,也可能读取连续的一列,而很多情况下会读取其中的 MxN 的一个小块。

然而,上述复杂读取的需求并不能简单地改变矩阵存储的布局,这是因为相应的数据通常需要支持两种不同的访问模式。一般来说会按照顺序连续写入共享内存,但是读取时无法和写入时保持相同的顺序。

为了解决上述问题,需要设计特殊的地址计算方式,使得共享内存中的地址和矩阵的行、不再是简单的乘法关系,这些方法中最通用的是基于异或的 swizzle 布局。Swizzle 布局可以支持连续访问和一种特定的不连续访问模式,因此可以解决大部分共享内存读写访问模式不同的需求。

Swizzle 布局在 CuTe 中定义为:

// A generic Swizzle functor
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
 *                               ^--^ MBase is the number of least-sig bits to keep constant
 *                  ^-^       ^-^     BBits is the number of bits in the mask
 *                    ^---------^     SShift is the distance to shift the YYY mask
 *                                       (pos shifts YYY to the right, neg shifts YYY to the left)
 *
 * e.g. Given
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
 * the result is
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
 */
template <int BBits, int MBase, int SShift = BBits>
struct Swizzle{/*...*/};

含义是根据 B M S 三个模板参数,将 YY 移动到 ZZ 的位置,并进行异或。然而,上述的定义仅仅说明了具体的计算方法,并没有说明 swizzle 布局如何解决共享内存的 bank 冲突。因此,本文将就各种应用场景,说明如何计算出合适的 swizzle 布局参数。

经典访问模式:按列读取矩阵

在各种访问模式中,最基本的访问模式是按照列来读取一个矩阵,例如一个 8x8 的矩阵,我们需要找到一个方法,既可以高效地访问矩阵的一行,也可以高效地访问矩阵的一列。这个访问模式使用于各种需要在软件上对矩阵进行转置的情况。

由于 32 种颜色可视化难度较大,本文中都默认共享内存组织为 8 个 bank。假设读者已经正确掌握了一次共享内存请求的概念,每次要读取的内存就是对应一次共享内存请求的内容,不会被硬件进一步拆分,因此可以认为我们就需要一次共享内存事务中读取所有的内容。然而,如果我们使用简单的p = row * 8 + col布局,则一列的内容都分布在同一个 bank,读取一个 8x8 的矩阵,会产生 8 路 bank 冲突。

no swizzle

对于 Swizzle 有基本了解的话应该可以知道,这个场景下的计算非常简单,直接将列坐标从简单的col,转而使用的row * 8 + (row ^ col),或者是p ^ (p >> 3)(CuTe 中不区分行和列,用上述p = row * 8 + col的结果进行进一步的计算)计算即可得到。实际上矩阵可能不只 8 行,但我们可以把row中超过 8 的部分取模,将问题归一到 8 行内。那么,考虑到行的长度也是 8 ,那么完整的地址计算公式是p ^ ((p & (7 << 3)) >> 3),对应布局Swizzle<3, 0, 3>,也就是说,这种情况下 B=S=log(M), M=0。

Swizzle<3,0,3>

用线性布局的角度来看,这个布局是

[100100010010001001000100000010000001]\begin{bmatrix} 1 & 0 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 0 & 1 \\ 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

当按列访问时,访问的是后 3 列[100010001100010001]\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix},对应的前三行是 bank 对应的位。前三行行满秩,因此访问高位时对应 bank 的部分张成完整的 8 个 bank 空间。

简单的变形:加宽

前面已经处理了矩阵行数增加时的情况,那么一个很自然的问题是,如果我们保持上述的访问模式不变,想要加宽矩阵怎么办呢?例如,如果矩阵不只有 8 列,而是有 32 列,那么会发生什么呢?如果用线性布局去思考,或者直接去思考这个布局的二进制表达,很容易想到这个情况下我们布局的计算公式仍然是row ^ col,只不过row对应的位数移动了两位,并且需要限制row只使用低 3 位,因此用p表达的公式是p ^ ((p & (7 << 5)) >> 3),对应Swizzle<3, 0, 5>。据此我们可以得出,将矩阵加宽,只需要增加 S 即可。

Swizzle<3,0,5>

对应的线性布局是 [1000010001000010001000010001000000001000000001000000001000000001]\begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

每次访问一列,对应[100010001000000100010001]\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix},可以看到低位和 8x8 矩阵时是完全相同的。

不简单的变形:缩窄

可以加宽矩阵,那么自然就可以缩窄,例如将矩阵改为 8x4。有了前面加宽的经验,我们自然可以想到,这种情况下只需要减少 S 即可。事实也是如此,这样我们可以得到布局Swizzle<3,0,2>

Swizzle<3,0,2>

但是!!CuTe 中不允许这么做,程序里有一行:

static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");

这个限制可能是由于布局代数计算的原因而作出的限制。实际上,我们仍然想要套用用前面的row * 8 + (row ^ col)公式会发现,此时row只有两位了,剩下的一位是row本身做了变换。这个布局只适合于用p ^ ((p & (7 << 2)) >> 3)表示。用线性布局看看:

[1010001010001010001000001]\begin{bmatrix} 1 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 1 \\ 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

在访问一列时取[100010101010001]\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 1 & 0 & 1 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix}

可以明确地看到,第 3 列是row的第一位,他在运算中和第 5 列,也就是row的第 3 位做了计算。这个异或对于这个访问模式是可有可无的。

注意,这个布局访问 4 列 2 行、2 行 4 列时均没有 bank 冲突。也就是说,限制在 2 的幂范围内,访问任意的分块都无 bank 冲突。

既然 CuTe 不支持,那有什么办法呢?再仔细看看上面的布局可以发现,row的第一位本身就属于 bank 的前三行的范畴,可以不参与 swizzle 运算,得到下列的布局。

[1001001001001000001000001]\begin{bmatrix} 1 & 0 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

这个布局在访问一列时,取[010001100010001]\begin{bmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix}

可以看到前三行满秩,同样没有 bank 冲突。对应的布局并不是减少 S,而是减少了 B,得到Swizzle<2, 0, 3>

Swizzle<2,0,3>

简单的变形:访问多列

一个同样常见的场景是我们需要同时访问矩阵的多列。例如在保持同时请求的数据量不变的情况下,可以一次请求 4 行 2 列的矩阵。

很容易发现,这个情况下实际上等价于一个 4 行 4 列的矩阵一次访问一列,只不过是“列”的概念需要稍微修改。即考虑到是在两行当一行用,在Swizzle<2,0,2>的基础上,改为Swizzle<2,1,2>

Swizzle<2,1,2>

对应线性布局是在左上加了个一

[100000010100001010000100000010000001]\begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

注意到列的第三位实际上没有变化,整个访问模式是以 4 行为周期。

实用的变形:隔行访问

有时候想要把大于一个 bank 的内容分配给一个线程,那就要考虑具体的元素的归属了。如果在上面的布局的基础上,将多行并作一行,则可以实现每次访问不连续的行。在参数上,在Swizzle<2,1,2>中增加 S 即可,得到Swizzle<2,1,2>

Swizzle<2,1,3>

对应线性布局只是把添加的异或的部分往右移动。 [100000010010001001000100000010000001]\begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 0 & 1 \\ 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}

实用的变形:非 2 的幂

然而,实际上计算中用到的矩阵大小通常是受到具体的硬件限制,很难分配得恰好符合 2 的幂。例如,算法中最合适的矩阵大小可能是 8x24,并且仍然需要访问 1 行 8 列或者是 8 行 1 列的内容。怎么办呢?实际上答案是完全不需要做任何操作,只需要按照一行长度是 2 的幂时我们的row ^ col方法进行处理即可,只不过行的长度是 24,所以计算公式为row * 24 + (row ^ col)。此时显然已经无法使用p表示了。

Swizzle<3,0,3>-24

看图可以发现,这个情况下只是 8x8 的情况在横向再复制了两次。

特殊的变形:除不尽的情况

注意到前面加宽矩阵总是简单的,而缩窄矩阵则较为特殊。在行长非 2 的幂的情况下,如果一行的长度可以被同时访问的列数整除,则只需要把访问模式在横向进行复制。那么如果不能呢?先看看此时的布局:

8x9

例如矩阵的形状为 8x9,想要访问 4 行 2 列,则对col做异或无法完全消除 bank conflict了。天无绝人之路,前面既然可以一行作两行,那现在也可以两行作一行,把这个矩阵当成 4x18,那就完全没问题了。

swizzle2-9

讨论

上面的分析中,我借助了线性布局来进行分析,但是具体计算的过程中并不需要线性布局也可以推导出来,计算也并不算复杂,适宜作为一个头文件存在。只可惜没有现成的轮子,不知道这个功能在哪里有前任的实现。

可以看到,CuTe 中的 Swizzle 布局只是一个计算时的定义式,很不容易直接从中读取出这个布局的含义,这也是为什么 Swizzle 布局一直认为非常难懂的原因之一。