一种 TP-SP-EP 混合并行策略

2758 字
7 min read

一、两种混合并行图示

非完整图示,不含后续 EP 并行 MoE 层。


二、并行原理解析

2.1 前提:qkv_inear (列切)

两种方案都始于一个列并行 (Column-Parallel)qkv_inear 层。

  • 我们有 NN 个 GPU。
  • 输入 XX 是复制的 (replicated)。
  • 第一个 qkv inear 层的权重 AA 被按切分:A=[A1,A2,,AN]A = [A_1, A_2, \dots, A_N]
  • GPU ii 计算:Yi=GeLU(XAi)Y_i = \text{GeLU}(X A_i)
  • 关键状态:计算完成后,中间激活 Y=[Y1,,YN]Y = [Y_1, \dots, Y_N]NN 个 GPU 上是按隐藏层维度HdimH_{\text{dim}} 维度,也常称为 KK 维度)切分的。

这里涉及到 Attention 的 TP 并行,原理可参考猛猿大佬文章 https://zhuanlan.zhihu.com/p/622212228,不再赘述。现在,我们要计算第二层 Z=YWZ = YW,其中 WWout_linear 的权重。

2.2 方案一:out_linear (行切) + all_reduce + Slice

这个方案的核心思想是:保持 HdimH_{\text{dim}} 维度的切分

  1. 数据排布:- 输入 (YY)Y=[Y1,,YN]Y = [Y_1, \dots, Y_N] (按 HdimH_{\text{dim}} 切分)。- 权重 (WW)out_linear 权重 WW 必须同样HdimH_{\text{dim}} 维度(即)切分: W=[W1W2WN]W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_N \end{bmatrix}
  2. out_linear (局部计算)
    • GPU ii 拥有 YiY_iWiW_i
    • 它只能计算它所拥有的那部分乘积:Zi=YiWiZ_i = Y_i W_i
  3. all_reduce (通信)
    • 根据矩阵乘法,最终结果是 Z=YW=i=1NYiWiZ = YW = \sum_{i=1}^N Y_i W_i
    • all_reduce 操作在所有 GPU 之间对 ZiZ_i 进行求和。
    • AllReduce({Z1,,ZN})Z\text{AllReduce}(\{Z_1, \dots, Z_N\}) \to Z
  4. 完整的 Z
    • ZZ 在所有 GPU 上都是完整的、复制的 (replicated)。
  5. 每张 GPU 从 Z 的行维度平均 Slice 出一部分,转为序列并行送到下一层。
  • 优点
    • 节省内存:每个 GPU 只需要存储 1/N1/NWW 权重。这在权重(如 WprojW_{\text{proj}})非常大时至关重要。
  • 缺点
    • 通信瓶颈:必须在计算 ZiZ_i 之后执行一个 all_reduce。这是一个同步操作,通信量为 ZZ 的大小,可能会阻塞流水线。

2.3 方案二:all2all + out_linear (不切分)

这是 “张量并行 (TP) 切换到 序列并行 (SP)” 的策略。这个方案的核心思想是:通过通信改变数据的切分维度

  1. 数据排布
    • 输入 (YY)Y=[Y1,,YN]Y = [Y_1, \dots, Y_N] (按 HdimH_{\text{dim}} 切分)。
    • 权重 (WW)out_linear 权重 WW 不切分 (replicated)。每个 GPU 都有完整的 WW
  2. all2all (通信)
    • 这一步的目标是将 YY 的数据排布从“按 HdimH_{\text{dim}} 切分”转置为“按序列 (Sequence) 维度切分”。
    • 之前:GPU ii 拥有 YiY_i(形状 S×(Hdim/N)S \times (H_{\text{dim}}/N))。
    • 操作
      • GPU ii 将它的 YiY_i沿着 SS 维度切成 NN 块:Yi=[Yi(1)T,,Yi(N)T]TY_i = [Y_i^{(1)T}, \dots, Y_i^{(N)T}]^T
      • GPU iiYi(j)Y_i^{(j)} 发送给 GPU jj
      • GPU jj 收到来自所有 NN 个 GPU 的 {Y1(j),,YN(j)}\{Y_1^{(j)}, \dots, Y_N^{(j)}\}
    • 之后:GPU jj 将收到的块沿着 HdimH_{\text{dim}} 维度拼接起来(⚠️:这里会有一个 transpose 操作),得到 Y^j=[Y1(j),,YN(j)]\hat{Y}_j = [Y_1^{(j)}, \dots, Y_N^{(j)}](形状 (S/N)×Hdim(S/N) \times H_{\text{dim}})。
    • 结果YY 的排布从 NNS×(K/N)S \times (K/N) 的块(TP)转换成了 NN(S/N)×K(S/N) \times K 的块(SP)。
  3. out_linear (局部计算)
    • GPU jj 拥有 Y^j\hat{Y}_j (形状 (S/N)×Hdim(S/N) \times H_{\text{dim}}) 和完整的 WW (形状 Hdim×HdimH_{\text{dim}} \times H_{\text{dim}})。
    • 它计算 Zj=Y^jWZ_j = \hat{Y}_j W
  4. 最终结果
    • ZjZ_j 的形状是 (S/N)×Hdim(S/N) \times H_{\text{dim}}
    • 最终输出 ZZNN 个 GPU 上是按序列 (Sequence) 维度切分的。

工程实现上,它们是两种完全不同的并行范式,有着根本的取舍:

特性方案一 (out_linear [行切] + all_reduce)方案二 (all2all + out_linear [不切分])
策略标准行并行 (Row-Parallelism)张量并行 (TP) \to 序列并行 (SP) 转换
out_linear 权重 WW按行切分 (节省 N1/NN-1/N 内存)不切分/复制 (需要 NN 倍内存)
通信操作all_reduce (在计算之后)all2all (在计算之前)
通信内容输出 ZZ (形状 S×HdimS \times H_{\text{dim}})激活 YY (形状 S×HdimS \times H_{\text{dim}})
输出 ZZ 的排布复制的 (Replicated)按序列切分 (Sequence-Parallel)

结论:方案二牺牲了 WW 的内存(现在需要 NNWW),来换取将并行维度从 HdimH_{\text{dim}} (TP) 切换到 SS (SP),其主要目的是 all2all 替代 all_reduce,并利用通信-计算重叠来提升流水线效率。


三、通信量对比分析

通信量是决定这两种方案性能的关键因素。我们来详细分析一下,假设:

  • NN = GPU 数量 (TP 规模)
  • HdimH_{\text{dim}} = 隐藏层维度
  • dd = 数据类型大小 (例如 bfloat16 为 2 字节)

3.1 方案一:out_linear (行切) + all_reduce

  • 目标:计算 Z=ZiZ = \sum Z_i 并将 ZZ 分发回所有 GPU。
  • 通信对象:张量 ZiZ_i,其大小为 S×Hdim×dS \times H_{\text{dim}} \times d
  • 通信量分析
    • 在标准的 ring-allreduce 中,每个 GPU 在 N1N-1 步中发送数据,在 N1N-1 步中接收数据。
    • 为了完成求和与分发,每个 GPU 最终发送的总数据量约为 N1N×(S×Hdim×d)\frac{N-1}{N} \times (S \times H_{\text{dim}} \times d)接收的总数据量也约为 N1N×(S×Hdim×d)\frac{N-1}{N} \times (S \times H_{\text{dim}} \times d)
    • 每 GPU 的总通信量 (发送+接收):V1=2×N1N×(S×Hdim×d)V_1 = 2 \times \frac{N-1}{N} \times (S \times H_{\text{dim}} \times d)

3.2 方案二:all2all + out_linear (不切分)

  • 目标:将 YY 的切分方式从 HdimH_{\text{dim}} 维度 (TP) 转换为 SS 维度 (SP)。
  • 通信对象:张量 YiY_i,其大小为 S×(Hdim/N)×dS \times (H_{\text{dim}}/N) \times d
  • 通信量分析
    • all2all 操作中,每个 GPU ii 将其本地的 YiY_i (形状 S×(Hdim/N)S \times (H_{\text{dim}}/N)) 切分为 NN 块,每块 Yi(j)Y_i^{(j)} (形状 (S/N)×(Hdim/N)(S/N) \times (H_{\text{dim}}/N))。

    • GPU iiN1N-1 块发送给其他 N1N-1 个 GPU。

    • GPU ii 发送的总数据量为:(N1)×size(Yi(j))=(N1)×(SN×HdimN)×d(N-1) \times \text{size}(Y_i^{(j)}) = (N-1) \times (\frac{S}{N} \times \frac{H_{\text{dim}}}{N}) \times d

    • 同理,它也接收 N1N-1 块。

    • 每 GPU 的总通信量 (发送+接收):

      V2=2×(N1)×(S×HdimN2)×d=2(N1)N2×(S×Hdim×d)V_2 = 2 \times (N-1) \times (\frac{S \times H_{\text{dim}}}{N^2}) \times d = \frac{2(N-1)}{N^2} \times (S \times H_{\text{dim}} \times d)

对比

方案通信操作每 GPU 总通信量 ( V )
方案一all_reduceV1=2(N1)N×(SHdimd)V_1 = \frac{2(N-1)}{N} \times (S \cdot H_{\text{dim}} \cdot d)
方案二all2allV2=2(N1)N2×(SHdimd)V_2 = \frac{2(N-1)}{N^2} \times (S \cdot H_{\text{dim}} \cdot d)

得出:V2=1N×V1V_2 = \frac{1}{N} \times V_1。因此,方案二 (**all2all**) 在通信总量上具有明显优势。

除此之外,选择方案二还有其他的原因:

  1. 通信模式all_reduce 包含计算(Sum),而 all2all 只是数据交换(Transpose)。在某些硬件拓扑(如 NVLink Switch)上,all2all 几乎可以达到线速,效率极高。
  2. 通信重叠:方案二的 all2all 作用于 YY,它可以在 YY 被计算时重叠 (Overlap) 进行。方案一的 all_reduce 必须等待 out_linear 计算 ZiZ_i 完成后才能开始。
  3. 内存代价:方案二的优势是有代价的。它需要每个 GPU 都存储完整的 out_linear 权重 WW,而方案一只需要 1/N1/N 的权重。
  4. 序列并行 (SP):如果你的网络架构(例如 MoE EP 并行)被优化为在序列并行的输入上工作,那么方案二的输出(ZZ 按序列切分)可以直接喂给下一层,完全消除了后续对 ZZ 进行 all_reduce allgather 的需求

四、EP 并行的 MoE 层

4.1 假设一些参数:

  • 输入:XRS×HdimX \in \mathbb{R}^{S \times H_{\text{dim}}}
  • Router:为每个 token 选 kk 个专家(Top-k),得到
    • 专家索引:eid0,,E1S×k\text{eid} \in {0,\dots,E-1}^{S \times k}
    • 权重:wRS×kw \in \mathbb{R}^{S \times k}
  • 专家集合:共有 EE 个 experts
  • EP 规模:PP(同一个 EP group 中有 PP 张 GPU)
    • 每张 GPU 持有 E/PE/P 个 experts(参数分片)

两次 all2all**(Dispatch / Combine):**

text
EP 的本质是:**按专家维度切参数,但按 token 路由把激活在卡间重排**。因此每个 MoE 层固定两次集合通信。

- **Dispatch**:把 token 送到“持有目标专家”的 GPU
- **Combine**:把专家输出送回“token 所在的 GPU”,并按权重聚合
EP 的本质是:**按专家维度切参数,但按 token 路由把激活在卡间重排**。因此每个 MoE 层固定两次集合通信。

- **Dispatch**:把 token 送到“持有目标专家”的 GPU
- **Combine**:把专家输出送回“token 所在的 GPU”,并按权重聚合

MoE 输入的 SP 排布:

  • 全局 SS 个 tokens,被 PP 张 GPU 按序列维度均分
    • GPU ii 拥有:XiRSP×HdimX_i \in \mathbb{R}^{\frac{S}{P} \times H_{\text{dim}}}
  • 每个 GPU 本地计算 Router:
    • eidi0,,E1SP×k\text{eid}_i \in {0,\dots,E-1}^{\frac{S}{P} \times k}
    • wiRSP×kw_i \in \mathbb{R}^{\frac{S}{P} \times k}

4.2 Dispatch:permute + all2all(把 token 发到专家所在卡)

目标:将 token 从“按序列切分”的排布,变换为“按专家分桶并落在对应 GPU”的排布。

本地分桶(bucketize)/ 打包(pack):

  • 对 GPU ii 上的每个 token tt,它会被路由到 kk 个专家:eidi[t,1],,eidi[t,k]{\text{eid}_i[t,1], \dots, \text{eid}_i[t,k]}
  • 定义专家到 GPU 的映射(静态):
    • owner(e)0,,P1\text{owner}(e) \in {0,\dots,P-1}
  • GPU ii 将其本地 token 复制出 kk 份“token-expert 关联样本”,并按 owner(e)\text{owner}(e) 分桶:
    • 形成 PP 个发送缓冲区:sendbufij\text{sendbuf}_{i\to j}
  • 同时,GPU ii 记录两类索引用于还原:
    • src_slot:这个样本来自本地第几个 token
    • k_slot:这是 top-k 的第几路(用于乘权重)

第一次all2all

  • 所有 GPU 同时执行 all2allrecvbufj=i=0P1sendbufij\text{recvbuf}_{j} = \bigcup_{i=0}^{P-1} \text{sendbuf}_{i\to j}

Dispatch 后的数据排布:

  • GPU jj 得到按其本地 experts 分桶后的激活集合:
    • X^jRMj×Hdim\hat{X}_j \in \mathbb{R}^{M_j \times H_{\text{dim}}}
  • 其中 MjM_j 是路由到 GPU jj 的(token, expert)样本数(一般不均匀)。
  • 同时携带对应的还原元信息(如 src_slot / k_slot、以及回传路由所需的 index)。

    关键状态:Dispatch 后,激活不再保持原序列顺序,而是按专家分桶组织,便于专家侧批处理。

4.3 Experts:本地 grouped_gemm(只在持有的专家上算)

GPU jj 持有专家集合,每个专家是一个 FFN:

  • Expert ee 的参数:W1(e),W2(e)W^{(e)}_1, W^{(e)}_2
  • 对属于该专家的子 batch:X^j(e)\hat{X}^{(e)}_j
  • 局部计算:
    • H(e)=ϕ(X^j(e)W1(e))H^{(e)} = \phi(\hat{X}^{(e)}_j W^{(e)}_1)
    • O(e)=H(e)W2(e)O^{(e)} = H^{(e)} W^{(e)}_2

将所有专家输出拼接为:

  • O^jRMj×Hdim\hat{O}_j \in \mathbb{R}^{M_j \times H_{\text{dim}}}(与 X^j\hat{X}_j 一一对应)

4.4 Combine:all2all + unpermute + weighted_sum(送回并聚合)

目标:把专家输出返回到 token 的原属 GPU,并将 top-k 多路输出按权重聚合为一个 token 输出。

按来源 GPU 反向打包(pack-back):

  • Dispatch 时每个样本带有其“来源 GPU + 来源 token 位置(src_slot)+ k_slot”
  • GPU jjO^j\hat{O}_j 按来源 GPU 分桶:
    • 形成 sendbackji\text{sendback}_{j\to i}

第二次 all2all

  • 所有 GPU 同时执行 all2all,GPU ii 收到所有返回样本集合 recvbacki\text{recvback}_i

本地还原与加权聚合:

  • GPU ii 对其本地每个 token tt,收集来自 kk 路的返回输出 Ot,1,,Ot,k{O_{t,1},\dots,O_{t,k}}
  • 按 router 权重聚合:
    • Yi[t]=r=1kwi[t,r]Ot,rY_i[t] = \sum_{r=1}^{k} w_i[t,r] \cdot O_{t,r}

Combine 后的输出排布:

  • GPU ii 得到:
    • YiRSP×HdimY_i \in \mathbb{R}^{\frac{S}{P} \times H_{\text{dim}}}
  • 输出仍是 按序列维度切分(SP),可直接送入下一层。