DeepGEMM 学习指南:面向初学者的 FP8 GEMM 库解析
优质博客¶
- CUTLASS & GPU CUDA 编程解析 https://research.colfax-intl.com/blog/
一、总体设计与原理¶
1.1 DeepGEMM 是什么?¶
DeepGEMM 是由 DeepSeek 团队开源的一个专注于 FP8 矩阵乘法(GEMM)的高效库。它旨在为深度学习中的矩阵乘法提供高效且简洁的实现,同时支持常规密集矩阵乘法和混合专家模型(MoE)中的分组矩阵乘法。DeepGEMM 最大的特点在于其极简的设计 —— 核心计算内核仅约 300 行代码,这使得阅读和理解变得相对容易。此外,它采用了运行时即时编译(Just-In-Time, JIT)技术,在安装时无需编译内核,所有 GPU 内核会在运行时根据需要动态编译。这种设计减少了安装部署的复杂性,并能够针对不同硬件和矩阵规模进行定制优化。
1.2 存在的意义和目标¶
在人工智能和高性能计算领域,矩阵乘法是许多算法(如神经网络前向/反向传播、Transformer 自注意力等)的核心。传统 FP32/FP16 精度的 GEMM 在计算效率和内存带宽上遇到了瓶颈。FP8 是一种仅 8 位的浮点数格式,通过牺牲部分精度换取更快的计算和更低的显存占用。DeepGEMM 正是为了充分利用新硬件对低精度的支持而诞生:NVIDIA Hopper 架构(如 H100 GPU)原生支持 FP8 运算和引入了 Tensor Core 加速的新特性,例如张量内存加速器(Tensor Memory Accelerator, TMA)。DeepGEMM 专门针对这些新特性进行了设计,能够充分发挥 Hopper 架构的性能潜力。通过利用 TMA 进行高效的数据搬运以及 Hopper 的 warp 级矩阵乘法指令(Tensor Core),DeepGEMM 实现了对 FP8 GEMM 的高效计算。同时,FP8 精度带来的数值范围缩小问题通过细粒度缩放技术加以解决,避免了低精度下出现的数值溢出或下溢,保障计算稳定性。总的来说,DeepGEMM 诞生的目的在于提供一个干净、易读且性能卓越的 FP8 矩阵乘法实现库,为大型模型训练和推理(尤其是大模型和 MoE 场景)提供支持。
1.3 架构设计与特性¶
DeepGEMM 从 NVIDIA 官方的 CUTLASS 和 CuTe 库中汲取了一些概念,但避免过度依赖其中复杂的模板和代数结构。相反,它在架构上追求极简,仅保留少数核心内核函数,以更直观地展现 GPU 高性能内核优化技巧。尽管设计轻量,DeepGEMM 在各种矩阵规模上的性能可媲美甚至超越高度优化的专家级库实现(某些测试中较基于 CUTLASS 的实现快 1.4~2.7 倍)。这一成绩证明了其架构设计的有效性,也是初学者学习 GPU 编程优化的绝佳实例。
1.4 DeepGEMM 的技术亮点¶
总结来说,DeepGEMM 具备以下几个关键设计与原理方面的亮点:
-
高效利用 Hopper 架构硬件
深度优化针对 NVIDIA Hopper GPU,引入并充分利用了 Hopper 架构的新特性 Tensor Memory Accelerator (TMA) 进行数据搬运。TMA 是 Hopper 提供的一种硬件级数据传输加速功能,可以在 GPU 内核中异步、高效地将数据从全局内存搬运到片上内存,从而大幅提升矩阵乘法的数据吞吐效率。同时,DeepGEMM 使用 Hopper 的张量核心(Tensor Core)执行矩阵乘累加运算,实现对 FP8 运算的原生支持。总体而言,DeepGEMM 的实现充分“贴合”硬件,最大程度发挥了 Hopper GPU 的算力。
-
细粒度缩放技术
由于 FP8 数值精度较低,在运算中容易发生上溢或下溢。DeepGEMM 引入了细粒度的缩放 (fine-grained scaling) 策略来解决这一问题。具体而言,它会将矩阵数据划分为较小的块,每个块使用各自的缩放因子,将 FP8 张量在计算前按适当比例放大或缩小,以保持计算的数值稳定性。这种块级别的动态量化确保即使在 8 位表示下也能覆盖足够的数值范围。例如,在 DeepSeek V3 的 FP8 训练中,每 128 个元素使用一个独立的缩放系数,在乘法累加时先将 FP8 数值乘以缩放系数提升精度,再进行张量核心运算。通过这种方式,DeepGEMM 有效避免了低精度带来的精度损失和随机误差,实现了稳定可靠的 FP8 计算。
-
支持分组矩阵乘法 (Grouped GEMM)
DeepGEMM 特别支持分组矩阵乘法,这在 Mixture-of-Experts (MoE) 大模型中非常重要。MoE 模型将不同数据分配给多个“专家”子模型处理,对于每个专家都需要进行自己的矩阵乘法。传统库如 CUTLASS 提供的 grouped GEMM 允许每组具有不同尺寸,但 DeepGEMM 针对 MoE 的特点做了特殊优化:仅在 M 维度上进行分组,要求 N 和 K 维度相同,从而简化并优化批量处理。DeepGEMM 提供了两种分组场景支持:其一是连续布局 (contiguous layout),用于训练前向或推理预填充阶段,此时每个专家处理的样本数不同但矩阵形状一致,通过将不同专家的输入在 M 维拼接形成一个大矩阵(中间需按要求对每段进行填充对齐)进行一次性乘法计算。另一种是掩码布局 (masked layout),用于推理的自回归解码阶段:由于解码时每个专家得到的 token 数不可预知,DeepGEMM 提供了带掩码的分组 GEMM 接口,允许传入一个 mask 张量来指示哪些位置有有效数据,仅对有效部分进行计算,从而避免无用计算。这种分组支持使 DeepGEMM 在大规模 MoE 模型中表现出色,充分利用了专家并行的优势。
-
运行时 JIT 编译设计
DeepGEMM 采用了全面的 JIT 编译策略。所有 GPU 核心计算代码并非在安装时预先编译为固定的二进制,而是推迟到运行时根据实际使用的矩阵大小和硬件情况再进行编译。在初次调用某种规模的 GEMM 操作时,DeepGEMM 会通过 NVCC 或 NVRTC 动态生成并编译针对该形状的优化内核,并将其缓存以供后续使用。这种设计带来了多重好处:一方面避免了用户安装时繁琐的编译过程(缩短安装时间、降低对编译环境的依赖),另一方面能够让编译器针对具体矩阵形状和 GPU 架构进行更激进的优化(例如不使用通用模板而直接展开循环、裁剪无用路径等),从而在运行时获得更高的性能。JIT 还允许根据不同 GPU(如 Hopper 及未来的 Blackwell 架构)选择或编译不同代码路径,充分利用新硬件特性。需要注意的是,为了保证 JIT 的高效运作,DeepGEMM 设计了轻量的 C++ JIT 模块,避免过多 CPU 开销,并支持将已编译内核缓存到磁盘(默认缓存目录为用户主目录下的
.deep_gemm,可通过环境变量配置)。用户也可以通过环境变量控制 JIT 行为,例如DG_JIT_USE_NVRTC=1切换为 NVRTC 编译以提高编译速度(可快至 NVCC 的 10 倍,但某些情况下生成的内核性能略低)。 -
性能与适用场景
得益于上述技术,DeepGEMM 在多种矩阵形状下实现了一流的性能。官方测试显示,相较于基于 CUTLASS 的 FP8 实现,DeepGEMM 可加速约 1.4× 到 2.7×。甚至在一些矩阵规模上,其性能达到或超过英伟达专家手工调优的库(如 cuBLAS)的水平。这种性能优势使它非常适用于深度学习模型的推理阶段,用于降低大型模型、实时应用的延迟;在混合专家 MoE 模型中,针对专家并行的优化可显著提升训练和推理效率;以及在高性能计算领域,用低精度加速科学计算和大数据分析。总结来说,DeepGEMM 的开源为深度学习和 HPC 领域提供了一个高性能又易于理解的 FP8 GEMM 解决方案,不仅解决了低精度计算中的关键问题(如精度损失),也为研究者和开发者学习 GPU 优化和 Hopper 架构提供了宝贵的参考资源。
二、模块划分与职责¶
DeepGEMM 项目的代码仓库经过精心组织,分为多个模块/目录,它们各自承担不同的功能,共同实现了上述特性。主要的模块和结构如下:
-
csrc 目录(CUDA/C++ 内核实现)
这是 DeepGEMM 的核心,实现所有矩阵乘法运算的底层代码。csrc 内包含 CUDA C++ 源文件,编写了 FP8 GEMM 的具体计算内核以及相关辅助函数。这里的代码负责定义线程块如何载入数据、调用张量核心指令进行矩阵运算,以及对结果进行处理(例如加上累加矩阵 C、写回内存)。DeepGEMM 的高性能 FP8 GEMM 算子以及分组 GEMM 等都在此实现。值得注意的是,由于追求简洁,csrc 中仅包含少量核心内核函数(正如前述大约 300 行的核心内核)。通过不同的模板参数或编译时宏,这些内核可以覆盖各种矩阵大小和布局组合,而无需大量重复代码。这部分代码高度优化,直接与硬件打交道,例如使用 TMA 异步加载数据、调用 Hopper 的 Warp-Level GEMM 指令等(下面章节将详细分析)。csrc 同时还包含 JIT 编译相关的 C++ 实现代码:为了在运行时生成/编译内核,csrc 提供了调用 NVCC/NVRTC 的接口、生成内核名称与配置、缓存编译结果等逻辑。在开发模式下,csrc 中的代码会通过 CMake 构建为一个可供 Python 动态调用的模块。
-
deep_gemm 目录(Python 接口模块)
该目录包含 Python 代码,封装了对底层 C++/CUDA 内核的调用接口,方便在深度学习框架中使用。安装完成后,deep_gemm 作为 Python 包导入。它提供了一系列函数和类,与 PyTorch 等框架的张量进行交互。例如,deep_gemm 模块中定义了各类 GEMM 函数(如
fp8_gemm_nt等),这些函数内部会检查所需的内核是否已编译、调用 JIT 编译模块生成内核(如尚未缓存),然后通过 PyTorch C++ 扩展机制调用编译好的 CUDA 内核执行计算,并返回结果。在 Python 接口层,还提供了一些实用工具函数,帮助用户准备数据或调整运行参数。例如:
transform_sf_into_required_layout用于将给定的缩放因子张量转换为内核要求的布局(如转置并做 TMA 对齐)get_tma_aligned_size查询 TMA 所需的对齐大小get_mk_alignment_for_contiguous_layout获取 MoE 分组连续布局在 M维需要的对齐粒度set_num_sms/set_tc_util等用于设置使用的 SM 数量上限和张量核心利用率等性能调优参数
这些接口让用户在 Python 层面即可方便地配置和使用 DeepGEMM 的功能,而不需要深入 C++ 代码细节。
-
scripts目录
包含一系列脚本和配置文件,用于构建、安装和开发。主要脚本包括:sh(编译构建项目脚本),install.sh(安装脚本)和 develop.sh(开发模式下的环境配置脚本)。例如,develop.sh 脚本会自动设置必要的包含路径(如 Cutlass 和 fmt 库)、编译 JIT 模块等。开发者可以阅读或执行这些脚本来快速完成环境配置和库的编译。在日常使用中,运行 install.sh 即可完成安装,然后通过 Python import deep_gemm 引入库。这些脚本模块确保了 DeepGEMM 的安装和开发流程简洁一致。
-
tests目录
包含若干测试用例和示例代码,验证 DeepGEMM 功能的正确性和展示接口用法。例如有 test_layout.py(测试不同布局、分组情况下的正确性)、test_attention.py(测试 MoE 专家组合、掩码 GEMM 以及 MQA 特殊 kernel 的正确性)等。初学者可以通过阅读和运行这些测试,了解如何调用 DeepGEMM 提供的API、如何准备 FP8 数据和缩放因子,以及对比计算结果的误差。这些测试相当于示例代码,演示了在 PyTorch 张量上使用 DeepGEMM 的方法,也是进行学习和验证环境配置的良好起点。
-
third-party目录
存放第三方依赖库代码。本项目使用了 NVIDIA 的 CUTLASS 库(子模块形式引入)和 {fmt} 格式化库。CUTLASS 提供了一系列 CUDA 模板和工具,用于实现高性能 GEMM。DeepGEMM 虽然没有直接使用 CUTLASS 的复杂模板结构,但仍然借鉴了其中的某些组件或概念(例如张量核心调度、WMMA 操作封装等)。将 CUTLASS 作为子模块包含,一方面方便使用其数据结构/常量定义,另一方面也可用于参考对比。同时包含的 {fmt} 库用于格式化打印和日志(DeepGEMM 的 JIT 过程可能会打印编译命令等调试信息,用到此库格式化输出)。third-party 确保了这些依赖在编译时可用,不需要用户手动安装。
-
构建配置
仓库根目录下有
requirement.txt和setup.py等文件。其中CMakeLists.txt定义了 CUDA 内核编译和 PyTorch 扩展构建的规则,setup.py 则允许通过 pip 工具安装本项目。在执行 install.sh 时,会调用 setup.py 将 DeepGEMM 构建为 Python 包并安装到环境中。因此用户既可以通过脚本直接构建,也可以将其集成到 Python 包管理流程中。
通过以上模块划分,DeepGEMM 实现了清晰的职责分离:底层 CUDA 核心专注于性能优化,Python 封装负责易用性和集成,测试和脚本确保正确性和可复现性。初学者在阅读代码时可以先从 Python 接口入手,逐步深入到 csrc 核心代码,在不同模块中理清各部分的职责和联系。
三、接口使用方法¶
DeepGEMM 为用户提供了简洁的 Python 接口来调用 FP8 GEMM 内核。要在自己的项目中集成和使用 DeepGEMM,建议按照以下步骤进行:
-
环境准备与安装
确保具备支持 DeepGEMM 运行的环境。由于 DeepGEMM 针对 Hopper 架构优化,目前需要 NVIDIA Hopper 或更新架构的 GPU(计算能力 SM90,如 H100,或 SM100 等新架构)。此外,需要 CUDA 12.3 及以上版本的工具包(建议 12.9 或更高以发挥最佳性能),Python 3.8+,以及支持 C++20 标准的编译器。运行 DeepGEMM 还依赖 PyTorch 2.1 或更高版本作为张量接口,以及包含的 CUTLASS 和 fmt 第三方库。
-
基本 API 调用
安装成功后,在 Python 代码中即可通过 import deep*gemm 引入库,并使用其提供的函数完成矩阵乘法计算。例如,要执行一个基本的非分组 FP8 矩阵乘法,可以调用
deep_gemm.fp8_gemm_nt函数。DeepGEMM 约定以fp8_gemm*{layout}命名不同布局的 GEMM 函数,其中{layout}表示输入矩阵是否转置的组合(N=不转置, T=转置)。注意: 对于 Hopper (SM90) 架构,目前 DeepGEMM 仅支持 NT 内存布局的内核(即第一个矩阵 A 为行优先未转置,第二个矩阵 B 为列优先,相当于数学上执行 )。因此在 Hopper 上应使用fp8_gemm_nt变体;而对于更新架构 (SM100),DeepGEMM 则提供了 NT、NN、TN、TT 等全布局支持。举例来说,调用:pythonD = deep_gemm.fp8_gemm_nt(A_fp8, B_fp8, C_bf16, scaleA)D = deep_gemm.fp8_gemm_nt(A_fp8, B_fp8, C_bf16, scaleA)其中 A_fp8 是 FP8 格式的左侧矩阵,B_fp8 是 FP8 格式的右侧矩阵,C_bf16 是累加矩阵(如 BF16 或 FP32,用于存放 ),scaleA 是 A 矩阵对应的缩放因子张量(下文详述)。该函数会计算 并返回结果张量 D(通常 D 会与 C 拥有相同的数据类型,例如 BF16)。
-
关于缩放因子
使用 FP8 计算时,每个矩阵需要配套的缩放因子以实现前述细粒度量化。DeepGEMM 要求用户显式提供缩放因子张量,尤其是左侧矩阵 A 的缩放因子。缩放因子的格式因架构不同而异:在 Hopper (SM90) 上,缩放因子用标准的 FP32 浮点表示,每个缩放值对应一定数量的 A 元素。而在更新架构 (SM100) 上,引入了一种紧凑表示 UE8M0(Unsigned E8M0)来存储缩放因子:将 4 个缩放值打包进一个 32-bit int,以提高内存访问效率。DeepGEMM 提供了辅助函数
transform_sf_into_required_layout,可将普通的 FP32 缩放因子列表转换为目标架构所需的转置对齐布局(Hopper 上基本为转置排布的 FP32 矩阵,Blackwell 上则打包为 int)。使用该工具函数可以简化缩放因子的准备工作。需要强调的是,DeepGEMM 并不会自动推断缩放因子,也不会在内部执行 FP32<->FP8 的类型转换,这些步骤需由用户在调用库之前完成。这意味着在调用fp8_gemm_*之前,用户应根据数据范围计算合适的缩放因子,将原始 FP16/BF16/FP32 张量量化为 FP8 格式(例如通过逐元素除以缩放因子并四舍五入到最近可表示的 FP8 值),并将数据类型转换为 torch.uint8 或 torch.int8 来存储 FP8 值。DeepGEMM 提供的 FP8 GEMM 函数会将这些 uint8/int8 数据按照缩放因子解释为 FP8 数值,在张量核心中进行乘法累加运算。由于库专注于 GEMM 本身,这些前置的数据处理和后续的结果反量化需要由用户或上层框架处理。 -
结果与输出
DeepGEMM 的 GEMM 运算形式是 ,其中 D 和 C 通常共享内存(即结果写入传入的 C 张量或返回值中)。如果用户希望像标准矩阵乘那样得到 纯乘积,也可以传入零矩阵作为 C(或在调用前将 C 清零)。DeepGEMM 支持输出以较高精度保存,例如 C 和 D 使用 BF16/FP16/FP32,这样可以在 FP8 乘法后保留更多累加精度。实际上,DeepGEMM 的某些接口名称中包含数据类型信息,例如
gemm_fp8_fp8_bf16_nt表示 A/B 使用 FP8 输入,输出累加到 BF16 格式。一般来说,您应选择输出类型足够容纳累加结果,以避免进一步精度损失。
-
-
高级功能接口
除了基本的二矩阵乘法,DeepGEMM 还提供了一些针对特殊场景的接口:
-
分组 GEMM
针对 MoE 的场景,DeepGEMM 提供了分组矩阵乘法接口。常用的是 M 维连续分组接口,例如
m_grouped_fp8_gemm_nt_contiguous。使用该函数前,用户需要将多个专家的输入矩阵在 M 维拼接成一个大矩阵 A(同时拼接对应的输出 C 和缩放因子),并确保每个专家的数据块长度满足对齐要求。DeepGEMM 提供了get_mk_alignment_for_contiguous_layout()函数来获取此对齐边界;通常需要对每个专家的数据量做填充(pad)到最近的对齐大小。例如若 M 轴分块对齐要求是 16,那么每个专家的 token 数需填充到 16 的整数倍。调用m_grouped_fp8_gemm_*_contiguous后,DeepGEMM 内核会在一个大 Kernel 中处理所有专家的数据,相当于批处理多个 GEMM,但比逐个调用效率更高。对于掩码分组(不定长序列的情况,如推理解码时),可以使用m_grouped_fp8_gemm_nt_masked接口。该接口需要额外传入一个 mask 张量(例如形状为 [总 M 长度] 的布尔或 0/1 张量),标记大矩阵 A 中哪些条目是真实数据,哪些是填充。内核据此只计算有效部分,从而避免了对填充项的冗余计算。这在需要通过 CUDA 图执行且 CPU 无法及时提供分组信息的场景(如自回归生成)特别有用。还有针对 K 维分组的接口,例如k_grouped_fp8_gemm_tn_contiguous,用于 MoE 模型的权重梯度反向计算(固定 M 和 N,K 维按专家分组)。总之,DeepGEMM 的分组接口使得在不同专家之间共享形状的矩阵乘法能够高效执行,开发者需要根据场景选择合适的函数并准备好相应格式的输入数据。 -
特殊算子(MQA Kernel)
除了通用的 GEMM,DeepGEMM 还发展出一些特定用途的核函数。例如在 DeepSeek v3.2 的检索场景中,引入了**多查询注意力得分(MQA)**的计算内核。DeepGEMM 提供了
fp8_mqa_logits(非分页)和fp8_paged_mqa_logits(分页,用于流式解码)两个函数,实现了一种融合计算:对查询矩阵 和键值矩阵 进行点积得到 logits,同时对结果应用 ReLU 激活并乘以权重,再对各头求和,直接输出最终得分。这一系列操作原本需要多步完成,而 DeepGEMM 将其融合在单个CUDA核中完成,大幅提升了这一路径的性能。虽然一般用户可能不需要手动调用这些特殊接口,但它展示了 DeepGEMM 框架的扩展性——可以根据特定业务需求开发定制的内核并集成在库中,从而避免 Python 层多次调用的开销。例如,fp8_mqa_logits接受的输入包括 FP8 格式的查询向量 q(形状[seq_len, num_heads, head_dim])、FP8 键值对 kv(键值被特别组织为两个张量:FP8 的键值矩阵和 FP32 的缩放因子向量)以及 float 格式的注意力权重等,直接在 GPU 上完成形如 的计算并累加。这类例子体现了 DeepGEMM 不仅能做基础 GEMM 运算,也能通过 JIT 框架方便地扩展融合复杂算子。 -
实用工具与环境变量
前面提及的工具函数可以帮助处理缩放因子和张量对齐等需求。除了函数接口外,DeepGEMM 还支持通过环境变量配置一些运行参数:
DG_JIT_DEBUG:设置为 1 可以打印 JIT 编译的调试信息,默认 0 不打印。DG_JIT_CACHE_DIR:指定 JIT 缓存目录,默认在用户主目录下.deep_gemm。DG_JIT_USE_NVRTC:设置为 1 启用 NVRTC 作为 JIT 编译后端,加速编译过程。DG_PRINT_CONFIGS:设置为 1 则在每次 GEMM 调用时打印所选用的内核配置(如线程块大小、分块维度等),可用于了解 JIT 选择了怎样的优化策略。 通过这些配置选项,用户和开发者能够更透明地观察 DeepGEMM 的行为,并在需要时调整其 JIT 编译策略和调优参数(例如在调试或性能分析时打开日志)。
-
-
与其他项目集成
由于 DeepGEMM 提供的是 PyTorch 风格的扩展接口,集成方式非常直接——在 PyTorch 中可将其作为替代矩阵乘法的实现。例如,对于支持 FP8 训练的变体 TransformerEngine,DeepGEMM 可以作为其 GEMM 内核的替换,以加速 FP8 张量的矩阵乘。实际使用时,只需确保在使用 FP8 算子的代码路径调用 DeepGEMM 提供的函数即可。例如在一个自定义的 Linear 层 forward 中,使用 DeepGEMM 完成 FP8 weights 和 FP8 activation 的乘法,再进行后续操作。社区已有项目(如 vLLM 等)集成了 DeepGEMM,通过预先 warm-up(预编译)所有将用到的内核来减少首次调用延迟。这说明 DeepGEMM 在大型推理服务中也是可行的。总之,DeepGEMM 的接口设计贴近 PyTorch 原生操作,几乎不需要复杂的对接,即可在现有深度学习项目中替换/插入 DeepGEMM 来提升低精度矩阵乘法的性能。
四、底层算子实现¶
本节将深入分析 DeepGEMM 核心 GEMM 算子的内部实现原理,包括其调度方式、张量处理和融合策略等。这对于理解其高性能来源很有帮助。
4.1 矩阵乘法的块级调度
类似大多数高性能 GPU GEMM 实现,DeepGEMM 采用了块(tile)级划分和调度策略。即将大矩阵的乘法任务分解为许多小的子矩阵块的乘法,交由并行的线程块 (thread block) 来处理。具体而言,DeepGEMM 会选择一定的块尺寸 (例如 或 等),每个线程块负责计算输出矩阵中的一个 tile。通过这种划分,大矩阵乘运算被分摊到多个 SM 上并行完成。DeepGEMM 内部有一套启发式算法或预定义配置来选择最佳的块尺寸和分块配置(称为“配置(config)”)针对给定的矩阵形状。在 JIT 编译时,库会根据矩阵维度 (M, N, K) 以及 GPU 架构,选择最优的内核配置并编译相应代码。这种做法类似于 cuBLAS 等库根据输入大小选择不同实现,以在线程并行度和数据局部性之间取得平衡。DeepGEMM 提供了打印所选配置的选项DG_PRINT_CONFIGS,用户可以看到例如块大小、分块数目等信息。总的来说,块级调度确保了大矩阵被充分并行地覆盖,又能让每个线程块处理适当大小的数据以充分利用片上存储和计算单元。
-
Warp 组织和张量核心利用
在每个线程块内部,DeepGEMM 会进一步组织线程进行计算。Hopper 架构引入了 Warp 级组矩阵乘累加 (WGMMA) 指令,可让一组 Warps 协同执行一个较大的矩阵乘运算。例如让一个线程块内的若干个 warp 组成一个“warp 组”,共同完成子矩阵乘法。每个 warp 负责其中一部分计算,多 warp 的结果累加形成完整的 tile 输出。这相比传统每个 warp 计算独立 16×16 片块的方式效率更高,因为 WGMMA 可以减少跨 warp 同步和数据拷贝。此外,Hopper 的张量核心支持 FP8 输入、FP32 累加的两阶段乘累加模式。DeepGEMM 在实现中让张量核心以两阶段累加的方式工作:FP8 的乘法在硬件张量核心上执行,中间结果累积在更高精度(FP16/FP32)寄存器中,最后将累加和转换回需要的格式输出。这样一来,即使每次乘法只有 8 位精度,累加过程也不会严重累计误差,保证结果精度。实际实现中可能将 FP8 值在加载后提升为 FP16 参与运算,并使用 CUDA 的 FFMA(Fused Multiply-Add)指令进行融合乘法加法,实现累乘加速。值得一提的是,为了进一步优化 FP8 累加的精度和性能,DeepGEMM 曾对生成的 SASS(GPU 汇编)代码做微调,插入/交织 FFMA 和 FADD 指令以充分利用执行单元管线,提高吞吐。这种手工调整通常用于确保在张量核心计算间隙插入纯标量加法,以平衡各单元利用率。据开发者透露,NVCC 12.9 版本开始编译器已经能自动做 FFMA 指令交织优化,因此 DeepGEMM 后续去除了手动修改SASS的步骤。这体现了 DeepGEMM 针对 GPU 微架构的深入优化和与时俱进的调整。
-
内存访问与 TMA 加速
DeepGEMM 的高性能还来源于对显存访问的高效处理。传统的 CUDA GEMM 实现会将需要的矩阵块从全局内存拷贝到共享内存,然后供 warp 读取计算。Hopper 架构提供的 Tensor Memory Accelerator (TMA) 使这个过程更加高效。DeepGEMM 内核使用 TMA API 发起异步全局内存读取,将 A、B 矩阵的当前 tile 块加载到共享内存。与早期架构上的 cp.async 指令类似,TMA 可以隐藏内存延迟,实现加载和计算的并行流水。但 TMA 更智能,它支持三维(tile)内存访问,可以一次搬运一个矩阵子块(例如 128×K 的片段),并在后台将数据分片送达共享内存指定位置。DeepGEMM 利用了 TMA 的这些能力,实现双缓冲流水线:线程块通常分配两个共享内存缓冲区,一个用于当前计算的数据,另一个用于预取下一个 K 分块的数据。当 warp 组在对第 i 批次的 K 块执行矩阵乘时,TMA 会同时将第 i+1 批次所需的数据加载到另一个缓冲。计算完切换 buffer 即可继续,不需等待数据,从而计算与内存传输重叠。这种设计确保了张量核心运算始终“吃满”数据,尽量不因为等待内存而闲置,大大提高了整体吞吐。
-
细粒度缩放融合
前面提到每个矩阵 tile 可能需要自己的缩放因子。DeepGEMM 在内核实现中将缩放操作与数据加载或乘法运算融合,减少额外开销。具体做法可能是在将 FP8 数据读入寄存器或共享内存后,立即乘以对应的缩放因子(存储在寄存器或共享内存中)以转换为 FP16/FP32 实数,再参与矩阵乘法。由于每个缩放因子对应固定数量的元素,GPU 可以利用向量化指令一次性对多个元素应用相同缩放。这种融合避免了在外部逐元素转换,可以看作内核的“前处理”部分。对于不同架构的数据格式,DeepGEMM 会采用不同方式:在 Hopper 上,缩放因子矩阵本身以 FP32 存储并通过 TMA 读取,内核对每个 FP8 元素乘以 FP32 缩放;在新架构上,由于缩放值打包为 int,需要先解析出 4 个缩放值,再将对应 4 组 FP8 元素转换。在任一种情况下,这一步骤都被安排在计算流水线中,不引入显式的 CPU 参与。最终,使得 FP8->高精度的转换和矩阵乘累加是一个连续的流,而不是分两步的过程。
-
C + GEMM 融合(Epilogue)
DeepGEMM 的 GEMM 运算包括了 的累加形式,即将已有的矩阵 C 加到乘积上。这一操作在内核实现中作为**融合的 Epilogue(尾段)直接进行:当一个 tile 的乘法部分完成后,每个线程直接从全局内存读取对应位置的 C (或在开始时已通过 TMA 加载到共享内存),将其加到计算的结果上,然后写回全局内存成为 D。通过将加法融合在内核内,避免了计算完乘积后再启动一个核来加法的额外开销。值得一提的是,如果 C 矩阵为零矩阵,上述加法等效于无操作,这种情况下 DeepGEMM 实际上就实现了 的纯乘法。在实现中可能会判断 C 是否为 None 或零,从而跳过加载和加法以节省内存带宽。不过对于统一性,很多实现通常总是执行一次加法操作。对于输出结果,DeepGEMM 也可能在 Epilogue 阶段进行必要的饱和(clamp)**处理,即确保 FP8 输出不溢出可表示范围(如果输出也是 FP8 的话)。社区有人在阅读测试代码时注意到 DeepGEMM 在输出比较时有 clamp 操作。推测其内核在将累加结果写回时,可能对结果做了 clamp 到 FP8 范围(0xFF 表示 inf 之类)以便结果与参考实现对齐。总之,这些尾部处理确保了结果正确又高效地存储。
-
特殊算子的实现考量
对于 MQA 这样融合了非线性算子的特殊 kernel,DeepGEMM 在实现上采用了内核内融合多步运算的思路。比如 MQA logits 的 kernel,会在单个 CUDA kernel 中完成查询与键值的点积、ReLU 激活、乘以权重、跨 head 求和等操作。实现这类 kernel 时,需要仔细安排各步在线程中的计算次序以及如何利用共享内存存放中间结果。例如先计算 得到每个 head 的局部 logits,存在寄存器或共享内存,然后立即对该向量执行 ReLU(这一步每个元素独立,易于并行),接着每个线程再取对应的权重乘上去,最后通过线程间协作(sum reduce)将多 head 结果求和得到标量输出。所有这些操作串联在一个 kernel 中,省去了多次内存读写和 kernel 启动开销。由于 DeepGEMM 采用 C++20 和 CUDA 结合开发,编写这样的融合 kernel 相对容易:可用模板或宏参数化打开/关闭某些步骤,从而重用矩阵乘的主体部分。比如可以想象 DeepGEMM 可能在内部有“Epiligue”模板,用于定义对矩阵乘结果的后处理(默认是加 C 并写出,对于 MQA 则是 ReLU、乘权重、求和等)。这样,核心乘累加逻辑不用重复编写,只需定制不同的 epilogue 操作即可。因此学习 DeepGEMM 源码时,初学者可以留意这类通过模板参数实现不同功能融合的模式。这也是 CUDA 内核实现常用的融合策略:通过在一次内核中完成尽可能多的连续算子,来提升总体效率。
综上,DeepGEMM 底层算子的实现结合了先进的硬件特性(TMA、张量核心、WGMMA)、经典的优化手段(块划分、双缓冲流水、寄存器高精度累加)以及合理的融合(缩放、加法等合并到内核)。这种实现让它在仅约 300 行核心代码中达到了媲美甚至超越高度专业优化库的性能。阅读其源码,可以看到许多 NVIDIA GPU 编程的技巧在其中得到体现,是学习 CUDA 高性能编程的宝贵素材。
五、CUDA Kernel 细节¶
在上一节宏观了解了 DeepGEMM 算子的实现原理后,本节进一步关注具体的 CUDA kernel 编程细节,包括线程组织方式、调度机制以及性能优化技巧等。这些细节对于想深入理解甚至修改内核代码的同学非常重要。
-
线程块和 Warp 组织
DeepGEMM 的 CUDA kernel 通常以线程块 (thread block) 为调度单位,每个线程块处理输出矩阵的一个子块(tile)。在实现中,需要决定每个线程块使用多少线程以及如何划分任务。DeepGEMM 针对 Hopper 的特点,可能采用了较大的线程块配置,例如 256 或 512 个线程一个 block,以充分利用 SM 的资源。线程块内部又划分为若干 Warp(每个 Warp 固定 32 线程)。由于 Hopper 引入了 Warp 组 (warp group) 的概念,可以让多个 Warp 协同执行一个 WGMMA 操作,DeepGEMM 很可能将一个线程块内的 Warps 分组成几个组,每组负责计算 tile 的一部分。例如,如果输出 tile 为 128×128,一个 warp 组可能负责其中的 64×128 区域,另一组负责剩余 64×128 区域,然后把结果合并。【注:具体数值推测】Warp 之间通过共享内存或寄存器通信配合。通常,一个 WGMMA 指令即可由一组 warp 完成较大的计算,因此 DeepGEMM 在代码中需要使用 CUDA 提供的 合作组 (cooperative groups) 或底层的 __syncwarp 等机制确保同组的 warp 同步执行矩阵乘操作,并在需要时跨 warp 交换数据或累加部分和。线程组织的目标是让所有计算单元都繁忙工作且最大化数据重用。通过让多个 warp共同加载使用一块数据,可以减少重复的全局内存读取,提高缓存命中。
-
双缓冲流水线
在前面的实现原理部分提到,DeepGEMM 内核实现了计算和数据加载的重叠。这种流水线通常通过双缓冲技术实现。在 CUDA 代码层面,这意味着使用两个片上缓冲区(如两个共享内存数组)交替存放来自全局内存的 tile 数据。在实现中,会将内核的计算逻辑分为多个阶段:
- 预取阶段: 使用 TMA 或者 cp.async 指令,将下一批 A、B 子块读入共享内存 Buffer1。
- 计算阶段: 使用上一批已在共享内存的数据 Buffer0,执行矩阵乘累加。
- 交换阶段: 计算阶段结束后,Buffer0 和 Buffer1 角色对换,继续预取下一批数据到空闲的缓冲区,同时计算当前批次,以此往复。
整个过程中需要精心插入同步操作。例如,在使用 cp.async 时通常需要 __syncthreads 或者特殊的释放屏障指令来确保数据已到达共享内存再开始计算下一阶段。而 TMA 提供了更高级的同步机制,可能使用片上信号量,无需显式同步。DeepGEMM 作为基于 Hopper 的实现,很可能利用了 TMA 的 Tensor Semaphore 功能来协调双缓冲,这在 CUDA 代码中通过特殊 pragma 或 intrinsic 函数实现。对初学者而言,这部分代码可能看起来较复杂,但理解其意图就是实现 Copy-Compute Overlap。可以类比为现实中的“左手倒水,右手煮饭”,以达到总用时最短。深入代码时,可以尝试定位这样的模式:一组线程触发异步加载,然后做计算,再同步,然后再加载——这些就是双缓冲流水线的实现细节。
-
寄存器和共享内存优化
CUDA kernel 的性能很大程度取决于对寄存器和共享内存的高效使用。DeepGEMM 在这方面也做了大量优化:
- **寄存器分配:**FP8 运算会涉及将 8 位值扩大为 16/32 位,所以每个线程在计算过程中需要额外的寄存器来保存扩大后的值和累加和。DeepGEMM 通过 CUDA 编译器自动分配寄存器,并可能使用 #pragma unroll 等指导编译器展开循环,减少寄存器复用导致的读取开销。同时,为避免寄存器溢出导致溢出到本地内存(会降低性能),DeepGEMM 内核会控制每个线程处理的元素数量,使所需寄存器数目在硬件限制内。Hopper 每个 SM 支持更多寄存器,DeepGEMM 可能也调整了每线程计算份额以利用这一点。
- 共享内存布局:DeepGEMM 在将数据搬入共享内存时,会选取一种优化的布局。例如,为了让后续张量核心访存对齐,A 矩阵 tile 通常以行优先方式存放,B 矩阵 tile 以列优先方式存放。这样 Warp 在读取时就是连续地址,可充分利用内存带宽。DeepGEMM 还计划实现共享内存乱序 (swizzling) 技术,即在将结果写入共享内存或从共享内存读出时,对地址做一个重新排列,以降低 bank 冲突和提高并发访问效率。这在 ROADMAP 中被提及(“Shared memory swizzling for output”),可能尚未完全实现,但未来版本会加入。初学者可以留意代码中对共享内存索引的计算,理解其如何排列线程和数据以避免访问冲突。
- 内存对齐和边界处理:DeepGEMM 要求输入矩阵的某些维度满足对齐要求,例如上文提到的 M 维必须是 16 的倍数等。如果输入不满足,通常通过补零填充解决。这简化了内核实现,可以不必编写复杂的边界检查分支,从而消除分支开销。事实上,在 DeepGEMM 当前实现中,如果矩阵尺寸不是理想值,它仍然会按照最接近的对齐大小执行计算,只是多算的部分对应填充位,会在结果中被丢弃或通过mask忽略。例如 M 不是 block 大小的倍数时,kernel 会计算到下一个倍数,然后多出来的结果不会被使用。这种做法虽然有少量冗余计算,但避免了条件判断导致的 warp 分岔,整体更高效(后续 Roadmap 中提及会“Skip useless computation on M”,也许将来会更智能地跳过填充部分)。对于 K 维度,Hopper 的 WGMMA 也要求 8 或 16 的倍数对齐,DeepGEMM 通过让用户确保 K 可被整除(或在文档中注明限制)来简化实现。如果出现不能整除的情况,内核可能也会执行超出部分再忽略,所以通常上层会避免这种不匹配的情况。
-
性能调优策略
DeepGEMM 的高性能还源于一些灵活的调优机制:
- **Tensor Core 利用率 (tc_util):**提供的
set_tc_util接口允许用户设置一个大概的张量核心利用率参数。这可能影响 JIT 内核配置的选择。例如在矩阵规模很小的时候,强行使用过大的线程块和过多并行可能导致 Tensor Core 负载不高甚至浪费资源。通过调整 tc_util,DeepGEMM 也许会采用不同的分块策略来取得更好吞吐。虽然具体实现细节未公开,但提供这个接口表明 DeepGEMM 考虑了性能与并行度的平衡问题。 - SM 利用控制:
set_num_sms接口允许限制使用的 SM 数目。在某些多工作负载场景或测试场景下,用户可能希望 DeepGEMM 不要占满所有 SM,从而留出部分 GPU 给其他任务。通过限制 SM 数,DeepGEMM 内部会据此调整每次 launch 的线程块数量,避免占用全部 SM。虽然大多数用户不会用到这个设置,但它体现了对资源调度的细粒度控制,有助于在多流并发情况下优化性能。 - **JIT 编译参数:**DeepGEMM 在 JIT 编译 CUDA 内核时,可能会根据不同架构启用特殊的编译标志以进一步优化性能。例如 Hopper 的 NVCC 12.3+ 已支持新的WGMMA/TMA,用不同 PTX 优化选项可能产出更优 SASS。DeepGEMM 提供了
DG_JIT_PTXAS_VERBOSE和DG_JIT_PRINT_COMPILER_COMMAND等环境变量,可以让开发者查看底层编译器优化信息。这些对于深入分析性能瓶颈、观察寄存器分配、指令排布都很有帮助。当需要调优时,开发者可以开启这些日志,结合 Nsight 等性能分析工具,对内核的瓶颈(比如是否内存受限或计算受限)作出判断,进一步优化。 - **与 CUTLASS 的比较和借鉴:**DeepGEMM 作为一个简化的实现,也受益于对比 CUTLASS 等成熟库的性能。开发者在 Issue 中曾提问 DeepGEMM 与 CUTLASS 实现的区别。DeepGEMM 的作者指出,其实现借鉴了CUTLASS 3.6 中对 Hopper FP8 GEMM 的一些做法但进行了仔细的优化。例如 CUTLASS 的 grouped GEMM 在 M 维并没有特别优化,而 DeepGEMM 针对 MoE 做了更高效的分组设计。通过这种有的放矢的改进,DeepGEMM 在特定场景下超越了 CUTLASS 的性能。这对我们学习 CUDA kernel 优化也是启示:基于成熟方案,针对具体应用特点精简和优化,往往能取得更好效果而不引入额外复杂性。
- **Tensor Core 利用率 (tc_util):**提供的
总体而言,DeepGEMM 的 CUDA 内核细节展现出顶尖的 GPU 编程技巧。对于初学者,虽然一开始完全读懂这几百行高度优化的 CUDA 代码可能有难度,但正因为代码量有限且结构清晰(无过深的模板嵌套),反而是学习研究 GPU 内核的绝佳材料。通过理解其线程划分、数据流、同步和优化策略,读者可以对高性能 GEMM 内核的设计有一个全景的认识。
六、学习建议¶
DeepGEMM 将复杂的 FP8 GEMM 实现浓缩在一个相对小巧的代码库中,非常适合初学者循序渐进地学习 GPU 编程优化。以下是一些学习路径和建议,帮助读者从入门到深入掌握并参与开发贡献:
-
**掌握基础背景:**在深入代码之前,确保了解基本概念,例如矩阵乘法原理、FP8 数据格式(E4M3、E5M2等)的特点,以及 CUDA 编程基础(线程/ warp、共享内存、同步等)。可以参考本指南开头对 FP8 和 DeepGEMM 目标的介绍。如果对 NVIDIA Hopper 架构不熟悉,建议先阅读一些关于 Tensor Core 和 TMA 的资料。例如,NVIDIA 官方的 CUTLASS 文档和 Hopper 架构白皮书都讨论了 WGMMA 指令和 TMA 功能的使用。这些背景知识将有助于理解 DeepGEMM 代码中的专业术语和设计决策。
-
**阅读项目文档和示例:**通读 DeepGEMM 仓库的 README 文档和 News/Roadmap 部分,以了解项目的全貌和最新进展。特别是 Roadmap 列出了未来计划支持的特性(如增加 BF16 kernel、Ampere 架构支持等),这可以让你了解项目目前的限制和发展方向。接着,运行仓库中的测试例子(如 tests / 目录下脚本)。通过实际运行 test_core.py、test_layout.py 等,你可以验证你的环境配置是否正确,并观察这些测试如何调用 DeepGEMM 接口以及对结果的检查。这些测试既能帮助理解 API 用法,也能让你对 FP8 计算的精度和缩放因子作用有直观认识。
-
**从 Python 接口入手阅读代码:**打开 deep_gemm 目录下的 Python 源码(例如 init.py 或其他模块),看看 DeepGEMM 在 Python 层是如何封装的。通常这里会有对 C++ 扩展函数的 python 壳,以及 JIT 编译触发的逻辑。例如,找到
fp8_gemm_nt的 Python 定义,了解它接受哪些参数、是否有类型检查或前置处理。观察 Python 层是直接将参数传递给 C++ 函数,还是在调用前做了例如transform_sf_into_required_layout这样的处理。这一步可以帮助你理清 DeepGEMM Python API -> C++ 内核 的调用流程,也对理解参数(特别是缩放因子、布局等)有帮助。如果 Python 层有文档字符串或注释,仔细阅读。这部分代码通常简短且直观,非常适合作为切入点。 -
**理解JIT 编译机制:**深入看看 csrc 或 deep_gemm 中与 JIT 相关的代码。例如,也许在 Python 层有调用
torch.utils.cpp_extension.load或在 C++ 层有使用 NVRTC 编译内核字符串的实现。找出 DeepGEMM 在 runtime 是如何生成和编译内核的。可能存在一个专门的 JIT 管理模块,处理缓存文件名生成、调用 NVCC/NVRTC 编译等。当你发现相应代码时,尝试运行小规模矩阵乘,看它是否在 HOME/.deep_gemm 目录生成了二进制文件。这有助于你理解 JIT 编译的触发时机和缓存策略。如果代码中有打印编译命令的选项(如前述环境变量),可以尝试打开,观察实际编译命令行和选项,这将加深你对底层实现细节的认识。 -
**针对核心内核代码进行攻关:**找到 csrc 目录中实现 GEMM 的核心 CUDA C++ 源文件(可能名字中含有 gemm 或 wgmma 之类)。打开这个文件,先总体浏览代码结构。通常可以按以下顺序分析:
- **模板声明或宏:**看开头是否有宏定义BLOCK_SIZE_M, N, K等,以及模板参数,如数据类型、布局等。理解这些参数怎样影响后续代码生成。
- **Kernel 函数:**找到 global 函数(CUDA kernel)。它可能有诸如 gemm_kernel<…>(…) 的形式。分析该函数内做了哪些步骤,例如:共享内存声明 -> 计算全局线程块索引 -> 循环迭代 K 块 -> 调用 WGMMA 指令执行乘法 -> 写结果 等。可以尝试将之前性能原理部分的内容与实际代码对应起来。
- **WGMMA/WMMA 调用:**寻找类似
wgmma.mma_sync或wmma.load/mma/store之类的语句。如果有,它就是张量核心运算部分。看看周围代码,了解输入矩阵 tile 是如何加载到寄存器或者片上内存,并传递给这些指令的。 - **TMA 调用:**若代码使用了 TMA,需要寻找 NVIDIA 提供的 TMA API 使用。Hopper TMA 在 CUDA C++ 中可能以
nvcuda::wmma::tma::load这样的形式出现。如果找到,仔细阅读这些调用的参数,理解它在加载哪个矩阵的哪一块数据到哪儿。结合 CUTLASS 提供的示例或 NVIDIA 博客,对照理解 TMA 调用的含义。 - **同步和流水:**找出是否有
__syncthreads()或CG::wait_grou等同步语句,以及双缓冲切换的逻辑。可以尝试在纸上画出时间轴:线程块第一个阶段加载 A/B,第二阶段计算,同时预取下一块,如此往复,并标注出代码中对应的位置。这对掌握并行流水至关重要。 - **Epilogue 阶段:**在 kernel 结尾,看如何处理累加的 C 和写出 D。如果看到有判断或特殊处理 masked 的情况,理解其逻辑,比如忽略填充元素。 通过这样分块攻关,你可以把一整段复杂的 CUDA 内核实现拆解为若干部分各个击破。尽管可能一开始仍有不懂的细节(比如一些 PTX 伪指令或 intrinsic),不妨做记号稍后查询相关资料(如 NVIDIA 官方 PTX 文档)。每理解一处,都将使你对 GPU 内核有更深体会。
-
利用参考资料辅助理解
学习 GPU 底层优化有时需要参考一些外部资料加深理解。例如:
- 当遇到 WGMMA 指令相关代码,可以阅读 Colfax Research 的《Fast Matrix-Multiplication with WGMMA on NVIDIA Hopper GPUs》系列教程,了解 WGMMA 的工作原理和用法。
- 对 TMA 如仍困惑,可以参考 Colfax 的《Mastering the Tensor Memory Accelerator (TMA)》技术博客,该文详细解释了 TMA 的机制和编程模型。
- CUTLASS 官方文档和示例代码也是宝贵资源。特别是 CUTLASS 3.x 针对 Hopper 的实现,你可以对比 DeepGEMM,看哪儿相似哪儿不同。CUTLASS 的代码可能复杂很多,但文档里关于 tile 划分、pipeline 的说明是通用的,值得一读。 利用这些资料与 DeepGEMM 代码相互印证,会大大提高你的学习效率。每当看到 DeepGEMM 某段实现不理解时,尝试搜索相关关键词(例如 “WGMMA example”“CUDA TMA usage” 等),通常能找到解释原理的文章。把理论和实践对应起来,你对代码的理解会豁然开朗。
-
动手实验与性能验证
理论结合实践才能真正掌握。可以尝试自行修改或扩展部分代码来检验你的理解。例如:
- 写一个小的 FP8 GEMM Python 测试脚本,比较 DeepGEMM 的输出和 PyTorch FP32 GEMM 的输出,验证误差随缩放因子变化的情况,加深对缩放的认识。
- 修改
DG_PRINT_CONFIGS环境变量为 1,跑不同大小的矩阵乘法,观察 DeepGEMM 选择的 tile 配置,并总结模式。例如 不同范围下,block 尺寸如何变化。这练习有助于理解 JIT 调优策略。 - 如果具备一定 CUDA 基础,尝试在 csrc 中添加一些简单的调试输出(如使用 printf 输出线程块索引或中间计算值)。虽然 GPU 上调试不易,但小范围的打印可以帮助确认某些分支是否执行、某些计算次数等。不过要注意这些调试代码可能影响性能甚至需要裁剪矩阵规模以免输出过多。
- 基于你的理解,对内核做一些小改动看看效果。例如,更改一个 BLOCK_K 尺寸(确保对齐约束),重新编译运行测试,看性能或正确性有何变化。这种实验能将抽象的参数变动与实际性能结果联系起来,体会优化的敏感度。
-
参与开发和贡献
当你对 DeepGEMM 有了深入理解后,不妨尝试为项目做出贡献。这不仅巩固所学,还有助于开源社区。可以从以下方向着手:
- **完善文档和注释:**初学者的视角很适合发现文档中的不足之处。比如可以补充中文 README 或者注释说明,这对于后来的新手非常有价值(DeepGEMM Roadmap 中也提到需要“Polish docs”)。你可以提交 Pull Request 完善用法示例、说明缩放因子的计算方法等。
- **补充测试:**设计一些新的测试用例,尤其是在边界情况、不同数据类型(比如将来 BF16 支持)方面的测试。健全的测试会让库更稳健。
- **性能改进:**如果你对某部分实现有新的想法,或者发现了潜在的瓶颈,可以讨论并尝试优化。DeepGEMM 团队表示在某些方面性能还有提升空间,欢迎感兴趣的人提交PR。例如,你可以研究 Roadmap 上提到的“跳过无效计算”或“支持 Ampere 内核”,这些都是实实在在可推进的方向。一旦实现,将对社区有很大帮助。哪怕是优化一点点(比如调整寄存器使用或增加一个特殊 shape 的 kernel),都可能提升特定场景下的性能。
- **兼容性扩展:**尝试让 DeepGEMM 支持更多环境。例如现在仅支持CUDA 12.x/Hopper,如果你的开发环境有安培卡(A100 等),可以尝试按照类似思路编写 Ampere 的 FP16/INT8 核心(虽然 FP8 硬件支持不完善,但也许可用模拟方式)。提交这样一个扩展 PR 会极大拓宽 DeepGEMM 的适用范围,也能验证你对整个实现的掌握程度。
- **反馈和交流:**在深入研究过程中,你难免会有疑问或发现问题。这时欢迎在 DeepGEMM 的 GitHub Issues 区提出。社区和作者通常会给予解答。这既能解决你的问题,又为官方改进提供了参考。积极的交流也是开源学习的一部分。
最后,学习这样一个高性能库需要耐心和积累。DeepGEMM 作为开源项目,其价值不仅在于提供了现成的 FP8 GEMM 方案,更是一个难得的教学范例。通过以上循序渐进的路径,你将逐步弄清其总体设计、模块分工、接口用法和内核细节。当你真正理解并能够修改其中的代码时,你对 CUDA 优化和深度学习算子的认知将提升到新的高度。祝你学习愉快,在探索 DeepGEMM 的过程中收获满满!
参考链接¶
-
How can we integrate the DeepGEMM Fp8 GEMM implementation …
-
deep_gemm - vLLM
https://docs.vllm.ai/en/v0.10.2/api/vllm/utils/deep_gemm.html
-
Efficient GEMM Kernel Designs with Pipelining - SIGARCH
https://www.sigarch.org/efficient-gemm-kernel-designs-with-pipelining/
-
Mastering the NVIDIA® Tensor Memory Accelerator (TMA)
-
Fast Matrix-Multiplication with WGMMA on NVIDIA® Hopper™ GPUs
https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/
-
Deep Dive on CUTLASS Ping-Pong GEMM Kernel - PyTorch