基于 nano-vLLM 学习大模型推理关键功能

张开发
2026/4/16 6:25:56 15 分钟阅读

分享文章

基于 nano-vLLM 学习大模型推理关键功能
KV 缓存Prefix KV Cache / Paged KV Cache高性能编译与执行优化Torch Compilation、Triton、CUDA Graph张量并行Tensor Parallelism该框架极具入门学习价值本文将先介绍 nano-vLLM 的基本组成架构再对部分核心技术要点展开深入解析。3. 系统架构nano-vLLM 的架构非常有层次感。3.1. 整体架构概览图 2, 来自https://deepwiki.com/GeeeekExplorer/nano-vllm三层结构接口层User Interface Layer推理引擎中控层Inference Engine Layer显存管理和模型执行层Memory Management Model Execution Layer3.2. 类层面架构从类设计层面观察 nano-vLLM 的架构。图 3上图中四种颜色代表系统的四个组成部分浅蓝色入口和推理引擎中控层浅绿色模型推理浅红色KV Cache 管理浅紫色权重加载和矩阵计算的封装3.3. 源码层面划分源码规划上也较为简洁。目录结构如下12345nanovllm/├── engine├── layers├── models└── utilsengine引擎的入口、中控同时 KV Cache 比较简单代码也放在这个目录下。layers模型推理的通用组件内部包括linear、layernorm、rotary_embedding、attention、activation 等基础功能的封装可以被不同模型使用。models模型的实现依赖 layers 的组件实现不同模型的推理。utils不同层都可能会用到的工具函数。4. 连续批处理4.1. 概念理解1定义连续批处理 (Continuous Batching)是一种迭代级Iteration-level的调度策略。它以“Token 生成步骤”为调度粒度。通过动态地在每一轮迭代中替换已完成的任务消除了由于生成长度不一导致的 GPU 计算气泡极大地提升了系统的吞吐量。2朴素理解一个请求需要执行多轮不同请求需要执行的轮数不同系统一轮最多只能同时执行一批 N 个请求当一个批次里的请求参差不齐的完成时每完成一个请求就将其用新请求替代掉。对比传统批处理和连续批处理传统批处理 (Static Batching)必须等待 Batch 中生成序列最长的那个请求完成整个 Batch 才会释放。在此期间生成序列短请求完成后槽位会空转。连续批处理 (Continuous Batching)请求完成即退出新请求立即补位槽位始终满载。4.2. 最基础的连续批处理最简单的连续批处理不考虑 prefill 和 decode 的差异示例代码import time import threading import queue import random # 1. 初始化线程安全的等待队列 waiting_queue queue.Queue() MAX_BATCH_SIZE 3 # --- 模拟用户请求线程 (生产者) --- def user_request_producer(): request_id 1 while True: # 模拟用户随机到达每 1~2 秒来一个新请求 time.sleep(random.uniform(1, 2)) # 每个请求需要的 Token 长度随机3到8之间 req {id: fREQ-{request_id}, remain: random.randint(3, 8)} waiting_queue.put(req) print(f\n[用户端] 送入新请求: {req[id]} (预计长度: {req[remain]})) request_id 1 if request_id 5: break # --- 核心推理循环 (消费者/执行器) --- def inference_loop(): running_batch [] print(--- 推理引擎已启动 ---) iteration 0 while True: # A. 补位逻辑只要 Batch 没满且队列里有货就拉进来 while len(running_batch) MAX_BATCH_SIZE: try: # 使用 blockFalse如果队列空了直接报错进 except不阻塞推理逻辑 new_req waiting_queue.get(blockFalse) running_batch.append(new_req) print(f [调度] {new_req[id]} 进入 Batch) except queue.Empty: break # B. 推理逻辑如果当前 Batch 有任务就执行一次 Step if running_batch: iteration 1 print(*20 f{iteration} *20) # 模拟 GPU 推理耗时 (Step 耗时) time.sleep(1.2) # 当前 Batch 状态展示 active_ids [f{r[id]}(剩{r[remain]-1}) for r in running_batch] print(f[GPU推理] 处理中: {active_ids}) # 每一个请求的剩余长度减 1 finished_this_step [] for req in running_batch: req[remain] - 1 if req[remain] 0: finished_this_step.append(req) # C. 剔除逻辑做完的立刻踢出下一轮循环开头就会有新请求补进来 for req in finished_this_step: print(f [完成] {req[id]} 生成完毕释放位置) running_batch.remove(req) else: # 如果 Batch 和 队列都空了稍微歇会避免 CPU 空转 time.sleep(0.5) # --- 启动程序 --- if __name__ __main__: # 启动用户请求线程 t threading.Thread(targetuser_request_producer, daemonTrue) t.start() # 主线程执行推理循环 try: inference_loop() except KeyboardInterrupt: print(\n服务已停止)核心逻辑存储结构代码的核心有两个队列waiting_queue 负责存储请求线程不断接收到的新请求running_queue 负责存储已经运行但还没有结束的请求。迭代循环生产者持续往 waiting_queue 写入新请求迭代循环持续从 waiting_queue 获取新请求加入到 running_queue同时清理 running_queue 里已经完成的请求。4.3. prefill 优先的连续批处理prefill 优先的批处理需要区分 prefll 和 decode优先处理新请求示例代码import time import queue import random import threading # 核心队列 waiting_queue queue.Queue() running_queue [] MAX_BATCH_SIZE 4 def user_request_producer(): 修改点模拟爆发式请求到达以触发多请求 Prefill # 第一波爆发式到达 (3个请求同时进入队列) print(\n[用户] --- 爆发式请求到达 (3个请求) ---) for i in range(1, 4): req {id: fREQ-{i}, remain: random.randint(2, 5)} waiting_queue.put(req) print(f[用户] 请求 {req[id]} 进入等待队列) # 延迟一会儿再来第二波单点请求 time.sleep(5) print(\n[用户] --- 延迟请求到达 (1个请求) ---) req {id: REQ-4, remain: 3} waiting_queue.put(req) print(f[用户] 请求 {req[id]} 进入等待队列) def inference_loop(): print(--- 连续批处理引擎多请求 Prefill 模式 ---) iteration 0 while True: current_batch [] is_prefill_stage False # 1. 调度构建当前批次 # 只要 waiting_queue 非空就尽可能填满 MAX_BATCH_SIZE if not waiting_queue.empty(): is_prefill_stage True while not waiting_queue.empty() and len(current_batch) MAX_BATCH_SIZE: req waiting_queue.get() current_batch.append(req) elif running_queue: is_prefill_stage False current_batch list(running_queue) if not current_batch: time.sleep(0.5) continue # 2. 执行模拟推理 iteration 1 print(f\n{*15} Iteration {iteration} {*15}) if is_prefill_stage: print(f[PREFILL] 批量生成中: {[r[id] for r in current_batch]}) time.sleep(1.5) else: print(f[DECODE ] 批量生成中: {[f{r[id]}(剩{r[remain]}) for r in current_batch]}) time.sleep(0.4) # 3. 统一状态更新 for req in current_batch: req[remain] - 1 # 4. 统一判断生命周期 # 注意为了避免在遍历列表时删除元素我们先收集要删除的对象 to_remove_from_running [] for req in current_batch: if req[remain] 0: print(f [完成] {req[id]} 退出系统) if req in running_queue: to_remove_from_running.append(req) else: if is_prefill_stage: running_queue.append(req) print(f - {req[id]} Prefill 完成转入 running_queue) else: pass # 真正的从 running_queue 移除 for req in to_remove_from_running: running_queue.remove(req) if __name__ __main__: t threading.Thread(targetuser_request_producer, daemonTrue) t.start() try: inference_loop() except KeyboardInterrupt: pass叠加上 prefill 优先之后的连续批处理代码也较为简单主要是维护三个变量waiting_queue、running_queue、current_batch。5. KV Cache5.1. 概念理解5.1.1. KV Cache 的用途KV Cache 有两层用途。一是用在同一个请求的 Decode 阶段复用之前已经计算过的 KV 结果以避免重复计算二是用在不同请求之间使具有相同前缀的请求可以共享一部分 KV 数据这就是 Prefix KV Cache。5.1.2. PagedAttention 技术在 Cache 的存储层面PagedAttention 实现了显存的按需申请。由于 KV Cache 空间不再一次性预分配请求序列对应的物理地址是离散的。PagedAttention 的核心在于它能够直接读取这些物理离散的块来完成注意力计算这背后实现了一层从“逻辑连续地址”到“物理离散地址”的映射。对于没有接触过非 PagedAttention 实现的读者来说这种设计似乎理所当然按需申请、分页管理、地址映射、局部性原理——这些都是计算机科学中非常常规的思维甚至很难想到不这么写的理由。那么为什么 PagedAttention 会被认为是一项里程碑式的先进技术呢首先在 PagedAttention 出现之前业界普遍认为 KV Cache 在显存中必须物理连续否则会因访存不连续导致性能大幅下降。其次当时的注意力算子如标准的 FlashAttention并不支持二次寻址映射。PagedAttention 证明了即便物理存储不连续性能依然可以保持极高。其代码实现最关键的点在于重构了 CUDA 内核使其原生支持 KV Cache 二次寻址。一个序列的 KV Cache 不需要物理连续也正是不同序列间能够灵活复用 Prefix KV Cache 的技术前提。总的来说虽然分页虚拟内存在 CPU 领域是常识但在 GPU 算子领域其发展相对缓慢。实现一套既能分页管理、又不损失算力利用率的 Attention Kernel 是 PagedAttention 的核心所在。5.2. Prefix KV Cache 的实现Cache 的管理较为简单只有 BlockManager 类负责维护显存池各个 block 的的状态。5.2.1. 功能细节使用 hash 来识别是否有可复用前缀以 block 为基本单元链式 hash每个 block 的 hash 计算输入为前序 block 的 hash 值加上本 block 的 token id每一个 block 有对应的 meta 信息对象记录 block 被复用的引用计数确保复用时不会被释放为避免 hash 碰撞出现错误block meta 信息还需要记录原始的 token id在获取 KV Cache 空间时需要考虑是否跨 block5.2.2. 内存池在进程启动时一次性申请内存池的空间kv_cache torch.empty( 2, # K 和 V num_layers, # 层数 num_blocks, # 总块数 block_size, # 每块 token 数 num_kv_heads // tp_size, # KV head 数考虑张量并行 head_dim # head 维度 )上述申请显存代码中的 num_blocks 是根据可用于 KV Cache 的显存算出来的block_bytes 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize config.num_kvcache_blocks int(total * config.gpu_memory_utilization - used - peak current) // block_bytes上述代码中的 block_bytes 并不是指 block 的大小而是把计算 blocks 数的所有除数乘到了一起除数包括block 大小、k 和 v、模型层数。total * config.gpu_memory_utilization - used - peak current 这部分则是根据最高的显存利用率算出来可用显存减去当前模型加载完后使用了的部分再减去模型预热时使用的激活显存peak - current。申请到内存池后按层共享视图给各个层的 Attention 对象代码看起来比较 tricky但在 python 里倒比较常见for module in self.model.modules(): if hasattr(module, k_cache) and hasattr(module, v_cache): module.k_cache self.kv_cache[0, layer_id] module.v_cache self.kv_cache[1, layer_id] layer_id 1遍历模型中的所有 nn.Module 子模块通过检查是否存在 k_cache 和 v_cache 属性来识别 Attention 层。对于每个 Attention 层将其 k_cache 和 v_cache 属性替换为指向全局 KV Cache 显存池的张量视图这样所有层共享同一块连续的显存空间但每层只能访问自己对应的切片。5.2.3. KV Cache 写入在 attention 子层的 forward 前做 KV Cache 的写入使用的 store_kvcache_kernel 函数是 triton.jit 实现的代码也比较简洁triton.jit def store_kvcache_kernel( key_ptr, key_stride, value_ptr, value_stride, k_cache_ptr, v_cache_ptr, slot_mapping_ptr, D: tl.constexpr, ): idx tl.program_id(0) slot tl.load(slot_mapping_ptr idx) if slot -1: return key_offsets idx * key_stride tl.arange(0, D) value_offsets idx * value_stride tl.arange(0, D) key tl.load(key_ptr key_offsets) value tl.load(value_ptr value_offsets) cache_offsets slot * D tl.arange(0, D) tl.store(k_cache_ptr cache_offsets, key) tl.store(v_cache_ptr cache_offsets, value) def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor): N, num_heads, head_dim key.shape D num_heads * head_dim assert key.stride(-1) 1 and value.stride(-1) 1 assert key.stride(1) head_dim and value.stride(1) head_dim assert k_cache.stride(1) D and v_cache.stride(1) D assert slot_mapping.numel() N store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)使用 stride 函数来确认显存是否连续因为在 store_kvcache_kernel 的实现里会按照显存连续来读取指定位置的值。显存连续和不连续的例子import torch key torch.randn(2, 3, 4) print(key.stride()) key_t key.transpose(1, 2) print(转置后非连续:, key_t.stride()) key_t key_t.contiguous() # 重新分配使其连续 print(contiguous 后:, key_t.stride()) # 输出 # (12, 4, 1) # 转置后非连续: (12, 1, 4) # contiguous后: (12, 3, 1)6. cuda graph6.1. 概念理解CUDA Graph 是一种将一系列 CUDA 操作录制成图的技术在重复执行的固定操作序列场景下可以显著提升推理性能主要基于这几方面减少 CPU 与 GPU 之间的频繁同步和指令下发开销降低传统独立操作带来的控制流交互损耗减少执行过程中的 CPU 干预GPU 自主批量执行图内操作最大化 GPU 利用率降低延迟、提升吞吐。规避多次独立 CUDA Kernel 的启动固定开销多个 Kernel 打包后仅需一次调度触发大幅提升小 Kernel 密集场景的执行效率可配合显存池实现显存资源复用减少 “少量多次” 显存申请 / 释放的开销同时驱动会基于图内显存访问模式优化带宽利用率CUDA 驱动可获取操作序列的全局视图基于完整的依赖关系进行全局优化如 Kernel 顺序调整、资源合并等6.2. 功能细节录制时使用的张量内存地址在重放时必须保持不变也就是后面多次 replay 都会使用捕获时申请的变量空间捕获后的 graph 对象记录在成员变量里供下次推理时选择重放时选择比请求 batch size 大的最小 graph batch size捕获时不同的 batch size 共享相同的静态显存空间并让多个批次共享显存池使得虽然有多个 batch size但只会使用 Max Batch Size 的显存空间6.3. 示例代码import torch import torch.nn as nn # 1. 基础配置 device cuda D 512 # 维度 graph_bs [1, 8, 32] # 预定义的桶分桶尺寸 NUM_LAYERS 100 # 极深模型增加 Kernel 数量以放大 Graph 优势 iters 50 # 性能测试迭代次数 max_bs max(graph_bs) # 2. 定义深层模型 (产生约 400 个 Kernel) class UltraDeepModel(nn.Module): def __init__(self): super().__init__() self.blocks nn.ModuleList() for _ in range(NUM_LAYERS): block nn.ModuleDict({ ln: nn.LayerNorm(D).to(device), linear: nn.Linear(D, D).to(device), act: nn.ReLU() }) self.blocks.append(block) def forward(self, x): for block in self.blocks: identity x x block[ln](x) x block[linear](x) x block[act](x) x x identity return x model UltraDeepModel().eval() # 3. 静态缓冲区准备 static_input torch.empty(max_bs, D, devicedevice) static_output torch.empty(max_bs, D, devicedevice) graphs {} graph_pool None # 4. 录制阶段 (从大到小共享内存池) print(f--- 开始录制分桶 CUDA Graphs ---) for bs in reversed(sorted(graph_bs)): current_input static_input[:bs] # Warmup for _ in range(5): _ model(current_input) g torch.cuda.CUDAGraph() with torch.cuda.graph(g, poolgraph_pool): static_output[:bs] model(current_input) if graph_pool is None: graph_pool g.pool() graphs[bs] g print(f✅ 已录制桶 BS{bs}) # 5. 辅助函数根据实际 BS 匹配最近的桶 def get_bucket_bs(actual_bs): for b in sorted(graph_bs): if actual_bs b: return b return None # 6. 性能对比测试 (包含 Padding 逻辑) def benchmark(actual_test_bs7): print(f\n--- 性能测试开始: 实际请求 BS{actual_test_bs} ---) # 生成测试数据 test_data torch.randn(actual_test_bs, D, devicedevice) start_event torch.cuda.Event(enable_timingTrue) end_event torch.cuda.Event(enable_timingTrue) # --- 段落 A: Standard Eager Mode (直接跑 7 个) --- torch.cuda.nvtx.range_push(Eager_Mode) start_event.record() for _ in range(iters): _ model(test_data) end_event.record() torch.cuda.synchronize() eager_time start_event.elapsed_time(end_event) torch.cuda.nvtx.range_pop() # --- 段落 B: CUDA Graph Mode (Padding 对齐到 8) --- # 1. 路由逻辑 bucket_bs get_bucket_bs(actual_test_bs) if bucket_bs is None: print(f❌ 错误: 实际 BS{actual_test_bs} 超过了最大分桶 {max_bs}) return torch.cuda.nvtx.range_push(fGraph_Mode_Bucket_{bucket_bs}) # 2. 数据对齐 (Padding): 将 7 条数据拷入 8 的静态区域 # static_input 的前 7 行被覆盖第 8 行保持不变即 Padding 位 static_input[:actual_test_bs].copy_(test_data) start_event.record() for _ in range(iters): # 3. 重放分桶 8 的图 graphs[bucket_bs].replay() end_event.record() torch.cuda.synchronize() graph_time start_event.elapsed_time(end_event) # 4. 结果截断 (Slicing): 从静态区拿回前 7 条 final_res static_output[:actual_test_bs] torch.cuda.nvtx.range_pop() # 打印结果 print(f匹配到的桶: {bucket_bs} (Padding 浪费率: {(bucket_bs-actual_test_bs)/bucket_bs*100:.1f}%)) print(f{Mode:20} | {Avg Time (ms):15}) print(- * 40) print(f{Eager Mode:20} | {eager_time/iters:15.4f}) print(f{Graph Mode:20} | {graph_time/iters:15.4f}) print(- * 40) print(f 加速比: {eager_time/graph_time:.2f}x) print(f最终输出形状: {final_res.shape}) if __name__ __main__: # 测试不同的输入 BS benchmark(actual_test_bs7) # 触发对齐到 8 benchmark(actual_test_bs1) # 精确匹配到 1执行nsys profile --tracecuda,osrt,nvtx python3 cu2.py输出Collecting data... --- 开始录制分桶 CUDA Graphs --- ✅ 已录制桶 BS32 ✅ 已录制桶 BS8 ✅ 已录制桶 BS1 --- 性能测试开始: 实际请求 BS7 --- 匹配到的桶: 8 (Padding 浪费率: 12.5%) Mode | Avg Time (ms) ---------------------------------------- Eager Mode | 7.9331 Graph Mode | 1.0136 ---------------------------------------- 加速比: 7.83x 最终输出形状: torch.Size([7, 512]) --- 性能测试开始: 实际请求 BS1 --- 匹配到的桶: 1 (Padding 浪费率: 0.0%) Mode | Avg Time (ms) ---------------------------------------- Eager Mode | 8.1803 Graph Mode | 0.8011 ---------------------------------------- 加速比: 10.21x 最终输出形状: torch.Size([1, 512])查看 nsys图 4可以看到在 cuda graph 的时候SM 使用更充分。6.4. QA1为什么推理时 cuda graph 的选择要采用向上对齐的分桶策略即批次相等或稍大的而不是选择批次最大的虽然在建图Capture阶段系统会按照最大批次Max Batch Size预先申请并锁定静态显存空间此时即便选择最大批次执行也不会产生额外的显存容量浪费但会引入以下两个维度的性能损耗显存带宽的无效占用大模型推理尤其是 Decoding 阶段属于典型的访存密集型任务其瓶颈在于模型权重从显存到计算单元的搬运速度。即便大部分批次位置是 Padding空数据CUDA Graph 依然会严格执行录制时的内存寻址定义搬运完整批次的数据。使用过大的批次会导致 GPU 浪费极其宝贵的带宽去搬运“无效数据”从而增加单次推理的耗时推高推理延迟Latency。计算资源的无效占用GPU 调度器会根据图定义的规模预分配硬件资源如 SM 核心、寄存器、共享显存等。虽然 Padding 部分的计算逻辑极快但这些资源在整个 CUDA Graph 执行完成前无法被释放。这会导致 GPU 硬件处于“虚假繁忙”状态阻塞了其他潜在任务如多流并行等获取硬件资源削弱了系统整体的并发吞吐能力Throughput。7. Torch Compilation7.1. 概念理解torch.compile 能够将 PyTorch 张量计算相关的 Python 逻辑转化为更高效的中间表示在 CUDA 设备上通常是 Triton 内核代码也支持原生 CUDA 内核。相较于传统的即时执行Eager Mode这种方式通过优化计算内核本身带来显著的运行效率提升此外当输入张量形状、数据类型固定时torch.compile 还会自动启用 CUDA Graph 优化进一步放大性能收益。在 torch.compile 问世之前PyTorch 开发者若想追求高性能仅有两种核心选择一是使用 Eager Mode 接受其原生性能上限二是手动编写 Triton 或 CUDA 底层内核代码该方式开发门槛高、周期长、维护成本高。而有了 torch.compile 后开发者只需编写简洁易懂的 PyTorch Python 业务逻辑无需关注底层硬件适配与内核实现即可获得接近手写 Triton/CUDA 的优异性能大幅平衡了开发效率与运行性能。7.2. 使用方法应用 torch.compile 非常简单核心有两类使用方式装饰器方式在 PyTorch 函数上直接添加 torch.compile 装饰器定义时即完成编译声明显式调用方式通过 compiled_obj torch.compile(target) 显式编译目标对象后续调用 compiled_obj 即可使用优化后的逻辑另外可以对模型实例的直接编译对于 PyTorch 模型nn.Module 子类实例可直接传入 torch.compile 完成整体编译无需单独修饰 forward 方法。示例代码方式 1装饰器方式适用于函数 / 模型方法import torch import torch.nn as nn # 对普通PyTorch函数使用装饰器 torch.compile def my_tensor_func(x, y): return torch.matmul(x, y) torch.relu(y) # 对模型的forward方法使用装饰器 class MyModel(nn.Module): torch.compile # 修饰forward方法自动编译模型推理逻辑 def forward(self, x): return nn.Linear(10, 20)(x)方式 2显式调用方式适用于函数 / 模型灵活性更高import torch import torch.nn as nn # 显式编译普通函数 def my_tensor_func(x, y): return torch.matmul(x, y) torch.relu(y) compiled_func torch.compile(my_tensor_func) # 生成编译后的函数 output compiled_func(torch.randn(32, 10), torch.randn(10, 20)) # 调用编译后的函数 # 显式编译模型与方式3本质一致更强调“先编译后使用”的显式流程 class MyModel(nn.Module): def forward(self, x): return nn.Linear(10, 20)(x) model MyModel() compiled_model torch.compile(model) # 直接编译整个模型实例 output compiled_model(torch.randn(32, 10)) # 调用编译后的模型方式 3直接编译模型实例深度学习中最常用简化写法import torch import torch.nn as nn class MyModel(nn.Module): def forward(self, x): return nn.Linear(10, 20)(x) # 直接编译模型实例一步到位无需装饰器最简洁常用 model torch.compile(MyModel()) output model(torch.randn(32, 10))7.3. 性能对比下面以一个简单的例子对比 Eager Mode 和 Compiled Modeimport torch import time # 确保使用的是 GPU device cuda if torch.cuda.is_available() else cpu if device cpu: print(警告CUDA 不可用将使用 CPU 运行torch.compile 的优势在 GPU 上最明显。) def complex_operation_eager(x, y): z x * y z z x z torch.relu(z) return z.sum() torch.compile def complex_operation_graph(x, y): z x * y z z x z torch.relu(z) return z.sum() # 1. 准备数据 x torch.randn(10000, 10000, devicedevice) y torch.randn(10000, 10000, devicedevice) # 2. 热身 (Warm up) print(正在编译并进行多次热身以稳定 GPU 状态...) # 增加热身循环 for i in range(3): complex_operation_graph(x, y) if i 0: print(- 首次编译完成正在进行后续预热...) torch.cuda.synchronize() print(预热完毕开始正式测试。) def benchmark(func, x, y, label, iterations100): start_event torch.cuda.Event(enable_timingTrue) end_event torch.cuda.Event(enable_timingTrue) start_event.record() for _ in range(iterations): func(x, y) end_event.record() torch.cuda.synchronize() elapsed_time_ms start_event.elapsed_time(end_event) avg_time_s (elapsed_time_ms / 1000) / iterations print(f{label} 平均耗时: {avg_time_s:.6f} 秒) return avg_time_s # 3. 执行测试并计算加速比 with torch.no_grad(): print(- * 30) eager_time benchmark(complex_operation_eager, x, y, Eager Mode ) compile_time benchmark(complex_operation_graph, x, y, Compiled Mode) print(- * 30) # 计算加速比逻辑 speedup eager_time / compile_time improvement (speedup - 1) * 100 print(f性能提升结果:) print(f加速比 (Speedup): {speedup:.2f}x) print(f运行速度提升了: {improvement:.1f}%)运行输出正在编译并进行多次热身以稳定 GPU 状态... - 首次编译完成正在进行后续预热... 预热完毕开始正式测试。 ------------------------------ Eager Mode 平均耗时: 0.005690 秒 Compiled Mode 平均耗时: 0.001528 秒 ------------------------------ 性能提升结果: 加速比 (Speedup): 3.72x 运行速度提升了: 272.3%性能有数倍的提升。7.4. QA1既然 torch.compile 有这么大的好处为什么不能给所有的张量操作函数都加上 torch.compile有几方面的原因首先存在编译开销会导致首次运行显著变慢然后编译器生成的 Triton 内核或其他后端代码通常是针对特定张量形状、数据类型和设备配置优化的如果输入张量的这些属性频繁变化会反复触发重新编译即 “编译缓存失效”反而抵消性能收益第三torch.compile 自身会带来额外的显存开销用于存储编译后的中间表示、内核缓存等过多无差别使用可能导致显存不足OOM最后并非所有代码都能被成功图化优化如果张量操作中调用了非 PyTorch 原生的第三方库或纯 Python 原生逻辑会导致计算图中断此时编译器无法继续优化后续逻辑还需要将控制权交回给 Python 解释器产生不必要的上下文切换开销可能导致负优化。2torch.compile 支持生成原生 CUDA 内核代码但通常来说编译器自动生成的通用原生 CUDA 代码优化粒度不够精细而 Triton 内置了极强的 Autotuning自动调优能力针对深度学习张量计算场景做了深度适配因此在绝大多数深度学习任务中Triton 内核的性能通常优于自动生成的原生 CUDA 内核。8. Torch Compilation、Trition、CUDA Graph 三者的区别和联系8.1. 核心区别1Torch CompilationPyTorch 高层一站式性能优化入口用户态抽象接口作为面向开发者的顶级优化封装torch.compile 无需开发者关注底层硬件细节与优化实现其核心定位是对 PyTorch 张量计算逻辑函数 /nn.Module 模型进行端到端自动优化屏蔽了底层内核生成与执行优化的复杂性是绝大多数 PyTorch 开发者的首选性能优化工具。2Triton高性能 GPU 内核专用 DSLTriton 既是开发者手动编写高性能内核的领域专用语言DSL也是 torch.compile 自动化生成代码的核心目标后端。其中triton.jit 是 Triton 框架提供的即时编译装饰器定位为高性能跨平台 GPU 内核的手动开发入口抽象层级低于 torch.compile、高于原生 CUDA C。它允许开发者以 Python 风格语法编写 GPU 内核逻辑无需手动处理线程调度、寄存器分配等底层细节最终编译为高效 GPU 机器码用于满足定制化算子的高性能需求。3CUDA GraphGPU 底层静态任务流执行优化技术CUDA Graph 是一种静态任务流调度技术旨在消除主机端Host与设备端Device之间的交互延迟。它并非 “内核生成工具”也非 “用户态编程接口”而是针对 CPU-GPU 交互瓶颈的底层执行优化技术抽象层级最低。其核心作用是固化连续的 CUDA 内核调用序列与内存配置通过 “录制 - 重放” 模式消除重复内核启动、CPU-GPU 频繁通信的开销仅优化执行流程不改变内核本身的计算性能。8.2. 核心联系1torch.compile 依赖 triton.jit 实现高性能内核生成torch.compile 的默认底层编译器Inductor在 CUDA 设备上会自动将 PyTorch 计算逻辑转化为 Triton 内核代码并隐式调用 triton.jit 完成编译生成高性能 GPU 内核开发者无需手动编写 Triton 代码也无需感知 triton.jit 的存在。此外torch.compile 也支持生成原生 CUDA 内核作为 Triton 内核的可选补充方案。2torch.compile 集成 CUDA Graph 实现执行层二次优化当输入张量的形状、数据类型等属性固定时torch.compile 会自动启用 CUDA Graph 优化将编译生成的 Triton/CUDA 内核调用序列录制为 CUDA 图。后续重复执行该逻辑时直接在 GPU 上重放该图进一步放大性能收益实现 “内核计算优化” 与 “执行流程优化” 的协同增效。3triton.jit 自定义内核可与 CUDA Graph 手动协同开发者手动通过 triton.jit 编写并编译的自定义内核在批量重复执行输入形状固定的场景下可手动集成 CUDA Graph 完成 “录制 - 重放” 流程消除 CPU 对 GPU 的调度开销实现内核计算性能与执行效率的双重极致优化。4三者协同构建极致性能计算链路典型极致性能链路手动编写 triton.jit 定制内核 → 嵌入 PyTorch 模型 / 函数 → 通过 torch.compile 进行上层计算图优化算子融合、内存复用等 → torch.compile 自动启用 CUDA Graph 优化执行流程 → 实现 GPU 计算性能最大化。9. TP 模式TP 模式将矩阵计算按行、列拆分到多颗 GPU 上执行涉及两个关键点权重参数怎么加载、多核计算之间如何协同下面做介绍。9.1. 加载权重参数权重参数与矩阵计算强相关因此权重参数的加载逻辑通常与矩阵计算逻辑一同封装在同一个类中实现功能的内聚性。9.1.1. 关键技术点1参数文件中权重矩阵以 Key-Value 键值对形式存储读取时同样采用 Key-Value 方式解析。其中 key 对应权重矩阵在模型中的归属位置例如模型第 0 层 MLP 子层的 down proj 权重对应的 key 为 model.layers.0.mlp.down_proj.weight。2参数文件由训练流程写入、推理流程读取训练与推理两侧必须严格对齐 key 的命名规则。模型参数加载时会根据参数文件中的 key 名称在 nn.Module 对象中匹配并调用对应的 weight_loader 方法完成加载。3模型结构包含多个层级每一层内部又包含多个子模块不同子模块对应各自专属的参数加载方法。PyTorch 的 nn.Module 通过特殊方法 __setattr__将模型结构中的各个子模块构建为树形结构树形结构中每个叶子节点的路径与参数文件中的 key 一一映射通过该路径找到叶子节点后即可获取对应的参数对象 nn.Parameter而该参数对象绑定了其所属子模块的 weight_loader 方法。4矩阵乘法 A * B 遵循「A 的行 × B 的列」计算规则在模型推理中B 为权重矩阵实际访问时以列维度为主。为提升读取效率、避免缓存Cache频繁失效权重矩阵 B 通常以转置形式存储。TPTensor Parallelworker 加载权重时需适配该转置存储特性 —— 即权重矩阵第 0 维对应原始矩阵的列数据第 1 维对应原始矩阵的行数据。从上述技术点可得出核心对应关系参数文件中的模型结构以一个个 key 表示这些 key 按层级关系可构建为一棵路径树代码中的模型结构以有包含关系的类对象表示这些类对象同样构成一棵与参数文件路径树完全对应的树。9.1.2. 实操举例FFN 层 up proj 权重加载1假设 TP size2up proj 权重矩阵的原始形状为 [1024, 3072]下面介绍一个 TP worker 如何加载权重。2首先构造模型对象时会初始化 ColumnParallelLinear 对象并设定核心参数input_size1024output_size3072/21536按 TP 尺寸做均分。这两个参数最终用于初始化 nn.Parameter 对象对应代码为 self.weight nn.Parameter(torch.empty(output_size, input_size))需注意此处初始化的张量以 output_size 为行维度、input_size 为列维度。3随后启动模型权重加载流程先从参数文件中读取所有 key-value 键值对再通过 key 在 nn.Module 树形结构中查找对应的 nn.Parameter 对象匹配到后调用其绑定的 weight_loader 函数执行具体的参数加载操作。4参数加载阶段针对列并行模式需要对权重张量的第 0 维度进行拆分再根据当前进程的 tp_rankTP 进程编号确定本进程需要加载的权重区间完成分片权重的加载。注代码实现中会将 gate 矩阵与 up 矩阵进行合并加载到显存中因此实际加载流程会在此基础上增加几步额外步骤。9.1.3. 构造树形结构示例代码下面 demo 代码展示多个有层级的对象如何通过特殊方法 __setattr__ 构造树形结构.class MiniModule: def __init__(self, nameroot): self._name name self._modules {} self._parameters {} def __setattr__(self, name, value): if isinstance(value, MiniModule): self._modules[name] value elif name.endswith(_weight_loader): self._parameters[name] value super().__setattr__(name, value) def get_all_paths(self, prefix): 递归遍历并收集所有参数的完整路径 paths [] # 1. 先收集当前层级的参数路径 for p_name in self._parameters: full_path f{prefix}.{p_name} if prefix else p_name paths.append(full_path) # 2. 递归进入子模块传递更新后的前缀 for m_name, m_obj in self._modules.items(): new_prefix f{prefix}.{m_name} if prefix else m_name paths.extend(m_obj.get_all_paths(new_prefix)) return paths def q_weight_loader(): print(fthis is q_weight_loader) def down_weight_loader(): print(fthis is down_weight_loader) # --- 构造树形结构 --- model MiniModule(Qwen3) model.layers MiniModule(Layers) model.layers.attention MiniModule(Attention) model.layers.attention.q_weight_loader q_weight_loader model.layers.mlp MiniModule(MLP) model.layers.mlp.down_weight_loader down_weight_loader # --- 打印所有路径 --- print(遍历模型的所有参数路径) all_paths model.get_all_paths() for path in all_paths: print(f路径: {path})

更多文章