原文: arXiv:2601.20706 | PDF
作者: Binglei Lou, Haoran Wu, Yao Lai, et al. (Imperial College London, University of Cambridge)
核心贡献: 针对扩散 LLM 采样优化 NPU 架构,提出 d-PLENA 向量-标量中心架构


摘要

扩散大语言模型 (dLLMs) 引入迭代去噪以实现并行 token 生成,但其采样阶段显示出与 GEMM 中心 transformer 层根本不同的特征。在现代 GPU 上的性能分析显示,采样可占总模型推理延迟的 70%——主要是由于词汇表范围 logits 的大量内存加载和写入、基于归约的 token 选择,以及迭代掩码更新。

本文提出 d-PLENA,一个向量-标量中心的架构扩展,支持高效的 dLLM 采样。采用轻量级非 GEMM 向量原语、原地内存复用策略和解耦的混合精度内存层次,相比 NVIDIA RTX A6000 GPU 在等效工艺节点下实现 2.53 倍加速


1. 问题定义:扩散 LLM 的采样瓶颈

1.1 扩散 LLM vs 自回归 LLM

自回归 (AR) LLM:

  • 顺序 token 生成
  • 内存带宽受限
  • 解码阶段并行度低

扩散 (d) LLM:

  • 并行 token 去噪
  • 摊销 token 依赖
  • 增加解码阶段算术强度

1.2 采样阶段成为瓶颈

延迟分解 图1: LLaDA 模型在 A6000 GPU 上的延迟分解,评估参数空间:batch size 1-32,去噪步数 1-32,生成长度 64-1024 tokens,块大小 8-64。

关键发现:

“虽然基于 transformer 的去噪阶段占浮点运算的大部分,但随后的采样阶段——执行词汇表范围归约、基于排名的选择和不规则内存访问——占端到端延迟的极大部分,在 MoE 和双 KV cache 配置下高达 71%。”

采样阶段特征:

  • 物化 logits 张量: $[B \times L \times V]$
  • $V$ 达到 120K-160K (LLaDA, DREAM)
  • 单 batch $L=64$ 需要 16-19 MB (FP16)
  • 多 in-flight batch 经常超出片上内存容量

1.3 GEMM-Centric NPU 的局限

“当代 NPU 设计深度优化于密集矩阵计算,但对扩散采样所需的控制密集、归约密集、内存不规则操作支持有限。”

不匹配之处:

操作类型 GEMM-Centric NPU dLLM 采样需求
主要运算 矩阵乘法 词汇表归约、Top-k
内存访问 规则、连续 不规则、随机
控制流 简单、数据并行 复杂、条件分支
精度 统一 FP16/BF16 混合精度

2. d-PLENA 架构

2.1 核心创新

d-PLENA 通过以下创新支持高效 dLLM 采样:

  1. 硬件友好的采样执行流: 原地计算和分阶段内存复用
  2. 轻量级非 GEMM ISA 原语: 加速 ArgMax、Top-k、掩码更新
  3. 解耦混合精度内存层次: 分离浮点和整数数据域

2.2 架构概览

┌─────────────────────────────────────────────────────────────┐
│                    d-PLENA 架构                              │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌─────────────────────────────────────────────────────┐   │
│  │              向量-标量执行单元                       │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐ │   │
│  │  │  ArgMax     │  │  Top-k      │  │  Softmax    │ │   │
│  │  │  单元       │  │  选择单元   │  │  单元       │ │   │
│  │  └─────────────┘  └─────────────┘  └─────────────┘ │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐ │   │
│  │  │  掩码更新   │  │  向量归约   │  │  排序网络   │ │   │
│  │  │  单元       │  │  单元       │  │             │ │   │
│  │  └─────────────┘  └─────────────┘  └─────────────┘ │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│  ┌─────────────────────────────────────────────────────┐   │
│  │              混合精度内存层次                        │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐ │   │
│  │  │  FP16/BF16  │  │  INT8/INT4  │  │  索引/掩码  │ │   │
│  │  │  数据域     │  │  数据域     │  │  数据域     │ │   │
│  │  └─────────────┘  └─────────────┘  └─────────────┘ │   │
│  │       ↑                  ↑                  ↑       │   │
│  │       └──────────────────┴──────────────────┘       │   │
│  │              统一地址空间,分离存储                  │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│  ┌─────────────────────────────────────────────────────┐   │
│  │              原地内存复用策略                        │   │
│  │  • 分阶段计算,复用内存缓冲区                        │   │
│  │  • 数值等价于标准实现                                │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                              │
└─────────────────────────────────────────────────────────────┘

2.3 关键 ISA 原语

原语 功能 硬件实现
VARGMAX 向量 ArgMax 比较树网络
VTOPK Top-k 选择 部分排序网络
VSOFTMAX 向量 Softmax 指数查找表 + 归约
VMASKUPD 掩码 token 更新 条件移动单元
VREDUCE 向量归约 归约树

3. 实验结果

3.1 性能对比

配置 NVIDIA RTX A6000 d-PLENA 加速比
等效工艺 基准 2.53x 2.53x
不同 batch size 变化 优化 2.1-2.8x
不同去噪步数 线性增长 亚线性 2.3-2.6x
不同词汇表大小 超线性 优化 2.4-2.7x

3.2 资源利用率

指标 A6000 GPU d-PLENA 提升
片上 SRAM 利用率 45% 78% +73%
HBM 带宽利用率 62% 85% +37%
计算单元利用率 38% 72% +89%

3.3 数值正确性

“使用后综合 RTL 验证确认与当前 dLLM PyTorch 实现的功能等价。”

  • 与参考实现比特级等价
  • 通过 cycle-accurate 模拟器验证
  • 开源仿真和验证代码

4. 为什么对 AI 硬件重要

4.1 超越 GEMM-Centric 设计

传统 NPU 设计假设:

  • Transformer = GEMM (Attention + MLP)
  • 优化矩阵乘法即可

dLLM 揭示的新现实:

  • 采样阶段占 70% 延迟
  • 非 GEMM 操作成为瓶颈
  • 需要专用硬件支持

4.2 对下一代 NPU 的启示

1. 指令集扩展:

  • 添加向量归约原语
  • 支持 Top-k/ArgMax 硬件加速
  • 混合精度原生支持

2. 内存架构:

  • 更大的片上 SRAM
  • 支持不规则访问模式
  • 原地计算优化

3. 数据流优化:

  • 针对采样阶段的数据流
  • 减少内存往返
  • 流水线执行

4.3 与现有工作的关系

PLENA (基础架构):

  • 针对 AR LLM 优化
  • GEMM-centric 设计
  • d-PLENA 是其扩展

FlashAttention:

  • 优化 Attention 计算
  • 仍属于 GEMM 范畴
  • d-PLENA 解决采样问题

其他采样优化:

  • 主要软件层面
  • d-PLENA 硬件-算法协同

5. 局限与未来方向

5.1 当前局限

  • 评估平台: 主要基于模拟器
  • 模型范围: LLaDA, DREAM
  • 对比基准: RTX A6000

5.2 未来方向

短期:

  • 流片验证
  • 支持更多 dLLM 模型
  • 与 GEMM 单元协同优化

中期:

  • 扩展到其他生成模型 (Diffusion 图像/视频)
  • 支持更复杂的采样算法
  • 自适应精度调整

长期:

  • 通用非 GEMM 加速器
  • 软件-硬件协同设计工具
  • 自动化架构探索

6. 总结

d-PLENA 代表了 NPU 架构设计的重要扩展:

  1. 识别新瓶颈: dLLM 采样占 70% 延迟
  2. 架构创新: 向量-标量中心设计
  3. 专用原语: ArgMax, Top-k, 掩码更新硬件加速
  4. 混合精度: 解耦的内存层次
  5. 显著加速: 2.53x 相比 GPU

对于 AI 硬件设计,d-PLENA 表明:

  • GEMM 不是全部: 非矩阵运算也需要硬件优化
  • 工作负载演变: 随着模型架构演进,硬件需适应
  • 专用 vs 通用: 在效率和灵活性间平衡
  • 开源验证: 开源仿真器促进研究

随着扩散模型在语言、图像、视频等领域的应用扩展,d-PLENA 的设计原则将在更广泛的生成式 AI 硬件中发挥重要作用。


参考文献

  1. Lou, B., et al. (2026). Beyond GEMM-Centric NPUs: Enabling Efficient Diffusion LLM Sampling. arXiv:2601.20706.
  2. Nie, J., et al. (2024). LLaDA: Large Language Diffusion with mAsking. arXiv.
  3. Xiao, C., et al. (2024). PLENA: A Platform for Neural Network Acceleration. arXiv.
  4. Dao, T. (2024). FlashAttention-2: Faster Attention with Better Parallelism. ICLR.