Research Article
Beyond GEMM-Centric NPUs: 高效扩散 LLM 采样架构
原文: arXiv:2601.20706 | PDF
作者: Binglei Lou, Haoran Wu, Yao Lai, et al. (Imperial College London, University of Cambridge)
核心贡献: 针对扩散 LLM 采样优化 NPU 架构,提出 d-PLENA 向量-标量中心架构
摘要
扩散大语言模型 (dLLMs) 引入迭代去噪以实现并行 token 生成,但其采样阶段显示出与 GEMM 中心 transformer 层根本不同的特征。在现代 GPU 上的性能分析显示,采样可占总模型推理延迟的 70%——主要是由于词汇表范围 logits 的大量内存加载和写入、基于归约的 token 选择,以及迭代掩码更新。
本文提出 d-PLENA,一个向量-标量中心的架构扩展,支持高效的 dLLM 采样。采用轻量级非 GEMM 向量原语、原地内存复用策略和解耦的混合精度内存层次,相比 NVIDIA RTX A6000 GPU 在等效工艺节点下实现 2.53 倍加速。
1. 问题定义:扩散 LLM 的采样瓶颈
1.1 扩散 LLM vs 自回归 LLM
自回归 (AR) LLM:
- 顺序 token 生成
- 内存带宽受限
- 解码阶段并行度低
扩散 (d) LLM:
- 并行 token 去噪
- 摊销 token 依赖
- 增加解码阶段算术强度
1.2 采样阶段成为瓶颈
图1: LLaDA 模型在 A6000 GPU 上的延迟分解,评估参数空间:batch size 1-32,去噪步数 1-32,生成长度 64-1024 tokens,块大小 8-64。
关键发现:
“虽然基于 transformer 的去噪阶段占浮点运算的大部分,但随后的采样阶段——执行词汇表范围归约、基于排名的选择和不规则内存访问——占端到端延迟的极大部分,在 MoE 和双 KV cache 配置下高达 71%。”
采样阶段特征:
- 物化 logits 张量: $[B \times L \times V]$
- $V$ 达到 120K-160K (LLaDA, DREAM)
- 单 batch $L=64$ 需要 16-19 MB (FP16)
- 多 in-flight batch 经常超出片上内存容量
1.3 GEMM-Centric NPU 的局限
“当代 NPU 设计深度优化于密集矩阵计算,但对扩散采样所需的控制密集、归约密集、内存不规则操作支持有限。”
不匹配之处:
| 操作类型 | GEMM-Centric NPU | dLLM 采样需求 |
|---|---|---|
| 主要运算 | 矩阵乘法 | 词汇表归约、Top-k |
| 内存访问 | 规则、连续 | 不规则、随机 |
| 控制流 | 简单、数据并行 | 复杂、条件分支 |
| 精度 | 统一 FP16/BF16 | 混合精度 |
2. d-PLENA 架构
2.1 核心创新
d-PLENA 通过以下创新支持高效 dLLM 采样:
- 硬件友好的采样执行流: 原地计算和分阶段内存复用
- 轻量级非 GEMM ISA 原语: 加速 ArgMax、Top-k、掩码更新
- 解耦混合精度内存层次: 分离浮点和整数数据域
2.2 架构概览
┌─────────────────────────────────────────────────────────────┐
│ d-PLENA 架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 向量-标量执行单元 │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ ArgMax │ │ Top-k │ │ Softmax │ │ │
│ │ │ 单元 │ │ 选择单元 │ │ 单元 │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ 掩码更新 │ │ 向量归约 │ │ 排序网络 │ │ │
│ │ │ 单元 │ │ 单元 │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 混合精度内存层次 │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ FP16/BF16 │ │ INT8/INT4 │ │ 索引/掩码 │ │ │
│ │ │ 数据域 │ │ 数据域 │ │ 数据域 │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ │ ↑ ↑ ↑ │ │
│ │ └──────────────────┴──────────────────┘ │ │
│ │ 统一地址空间,分离存储 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 原地内存复用策略 │ │
│ │ • 分阶段计算,复用内存缓冲区 │ │
│ │ • 数值等价于标准实现 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
2.3 关键 ISA 原语
| 原语 | 功能 | 硬件实现 |
|---|---|---|
| VARGMAX | 向量 ArgMax | 比较树网络 |
| VTOPK | Top-k 选择 | 部分排序网络 |
| VSOFTMAX | 向量 Softmax | 指数查找表 + 归约 |
| VMASKUPD | 掩码 token 更新 | 条件移动单元 |
| VREDUCE | 向量归约 | 归约树 |
3. 实验结果
3.1 性能对比
| 配置 | NVIDIA RTX A6000 | d-PLENA | 加速比 |
|---|---|---|---|
| 等效工艺 | 基准 | 2.53x | 2.53x |
| 不同 batch size | 变化 | 优化 | 2.1-2.8x |
| 不同去噪步数 | 线性增长 | 亚线性 | 2.3-2.6x |
| 不同词汇表大小 | 超线性 | 优化 | 2.4-2.7x |
3.2 资源利用率
| 指标 | A6000 GPU | d-PLENA | 提升 |
|---|---|---|---|
| 片上 SRAM 利用率 | 45% | 78% | +73% |
| HBM 带宽利用率 | 62% | 85% | +37% |
| 计算单元利用率 | 38% | 72% | +89% |
3.3 数值正确性
“使用后综合 RTL 验证确认与当前 dLLM PyTorch 实现的功能等价。”
- 与参考实现比特级等价
- 通过 cycle-accurate 模拟器验证
- 开源仿真和验证代码
4. 为什么对 AI 硬件重要
4.1 超越 GEMM-Centric 设计
传统 NPU 设计假设:
- Transformer = GEMM (Attention + MLP)
- 优化矩阵乘法即可
dLLM 揭示的新现实:
- 采样阶段占 70% 延迟
- 非 GEMM 操作成为瓶颈
- 需要专用硬件支持
4.2 对下一代 NPU 的启示
1. 指令集扩展:
- 添加向量归约原语
- 支持 Top-k/ArgMax 硬件加速
- 混合精度原生支持
2. 内存架构:
- 更大的片上 SRAM
- 支持不规则访问模式
- 原地计算优化
3. 数据流优化:
- 针对采样阶段的数据流
- 减少内存往返
- 流水线执行
4.3 与现有工作的关系
PLENA (基础架构):
- 针对 AR LLM 优化
- GEMM-centric 设计
- d-PLENA 是其扩展
FlashAttention:
- 优化 Attention 计算
- 仍属于 GEMM 范畴
- d-PLENA 解决采样问题
其他采样优化:
- 主要软件层面
- d-PLENA 硬件-算法协同
5. 局限与未来方向
5.1 当前局限
- 评估平台: 主要基于模拟器
- 模型范围: LLaDA, DREAM
- 对比基准: RTX A6000
5.2 未来方向
短期:
- 流片验证
- 支持更多 dLLM 模型
- 与 GEMM 单元协同优化
中期:
- 扩展到其他生成模型 (Diffusion 图像/视频)
- 支持更复杂的采样算法
- 自适应精度调整
长期:
- 通用非 GEMM 加速器
- 软件-硬件协同设计工具
- 自动化架构探索
6. 总结
d-PLENA 代表了 NPU 架构设计的重要扩展:
- 识别新瓶颈: dLLM 采样占 70% 延迟
- 架构创新: 向量-标量中心设计
- 专用原语: ArgMax, Top-k, 掩码更新硬件加速
- 混合精度: 解耦的内存层次
- 显著加速: 2.53x 相比 GPU
对于 AI 硬件设计,d-PLENA 表明:
- GEMM 不是全部: 非矩阵运算也需要硬件优化
- 工作负载演变: 随着模型架构演进,硬件需适应
- 专用 vs 通用: 在效率和灵活性间平衡
- 开源验证: 开源仿真器促进研究
随着扩散模型在语言、图像、视频等领域的应用扩展,d-PLENA 的设计原则将在更广泛的生成式 AI 硬件中发挥重要作用。
参考文献
- Lou, B., et al. (2026). Beyond GEMM-Centric NPUs: Enabling Efficient Diffusion LLM Sampling. arXiv:2601.20706.
- Nie, J., et al. (2024). LLaDA: Large Language Diffusion with mAsking. arXiv.
- Xiao, C., et al. (2024). PLENA: A Platform for Neural Network Acceleration. arXiv.
- Dao, T. (2024). FlashAttention-2: Faster Attention with Better Parallelism. ICLR.