《Nano-vLLM 源码解读》第 16 篇 · Linear 投影

发布时间:2026/6/9 2:54:00

《Nano-vLLM 源码解读》第 16 篇 · Linear 投影 nano-vllm 用千行代码拆解 vLLM 核心是读懂大模型推理最快的捷径。1. 介绍上一篇里 RoPE 旋转的 q、k是从self.qkv_proj(hidden)一次投影、再split出来的。q、k、v 本是三个独立的线性投影nano-vllm 把它们合并成了一次。本篇解读qkv_proj所属的 Linear 家族讲清一件事把 q/k/v、gate/up 合并成一次投影为什么能省一次 kernel 启动。linear.py支持张量并行多卡切分。本篇着重介绍投影简单起见统一按单卡解读多卡切分后续单开篇幅介绍。2. Linear 的本质一个线性层把输入的一组数重新加权组合成输出的另一组数——输出里的每个数都是输入那组数的一次加权求和写成式子是y Wx b。为什么需要模型里每个 token 的隐藏向量是一团「缠在一起」的信息——词形、粗粒度词性、模糊的语义、位置顺序全揉在同一组数里杂乱无章。后面的计算想用上其中某一类得先把这团乱麻「解开、对齐」成明确具体的特征。解决了什么Linear 把这团乱麻投影成一组「对齐好」的新方向每个方向对应一个有意义的问题、彼此分开。拿 “red car” 举例真实里它是两个 token这里理想化成一个 1024 维隐藏向量投影出 4096 维后可能其中一维问「是不是交通工具」、一维问「是不是红色 / 暖色」、一维问「是不是人造物」、一维问「能不能移动、跑多快」……怎么解决每个输出方向就是W的一行——一个「概念提取器」它给输入的每个特征分配一个权重再加权求和。「是不是交通工具」那一行会给「有轮子」「能载人」「在路上」这些输入特征高权重给「是不是红色」近乎零权重。W是d_out × d_in的权重矩阵、b是偏置前向就一句y Wx b。import torch import torch.nn.functional as F # 把图里那团 1024 维隐藏向量理想化成 5 个看得懂的特征0~1 表强弱 # 有轮子 能载人 在路上 是红色 会发光 x torch.tensor([ 1., 1., 1., 1., 0. ]) # W 的每一行 一个「对齐方向 / 概念」给各特征配权重再加权求和 W torch.tensor([ [1., 1., 1., 0., 0.], # 是不是交通工具看重 有轮子/能载人/在路上 [0., 0., 0., 2., 1.], # 是不是红色暖色看重 是红色/会发光 ]) b torch.tensor([0., 0.]) y F.linear(x, W, b) # 投影出两个方向的打分 print(交通工具分, 红色分 :, y) # tensor([3., 2.]) print(维度变换 :, x.shape[0], -, y.shape[0]) # 5 - 2交通工具分, 红色分 : tensor([3., 2.]) 维度变换 : 5 - 2这种「把一组特征线性重组成另一组」的操作是深度学习里最高频、最吃算力的环节——Transformer 一层里 qkv、o、gate、up、down 全是 LinearLinear 只做线性重组要有非线性表达力还得在两层之间夹一个激活函数。本篇要拆的qkv_proj就是一次1024 → 4096的线性投影只是它对齐出的方向供注意力当 q、k、v 用。3. 总览Linear 家族在 nano-vllm 里就一个文件linear.py一个基类LinearBase派生出五个子类。Qwen3 一层 decoder 里四处线性投影各用其中一个。类用在哪投影ReplicatedLinear-—ColumnParallelLinear仅作基类—QKVParallelLinearqkv_proj合并 q/k/v1024 → 4096MergedColumnParallelLineargate_up_proj合并 gate/up1024 → 6144RowParallelLinearo_proj/down_proj2048→1024 / 3072→1024按 forward 出场顺序四处投影是qkv_proj → o_proj → gate_up_proj → down_proj。其中两个是「合并投影」表中的QKV、Merged——把本该分开的几次投影并成一次正是本篇要介绍的。单卡下ColumnParallelLinear/RowParallelLinear/ReplicatedLinear的forward都退化成一句F.linear类名里的「Parallel」代表支持多卡并行。4. 合并投影q、k、v 三个投影输入的是同一个hidden。把三个权重矩阵在输出维上拼成一个大矩阵一次F.linear算出[N, 4096]再按split([2048, 1024, 1024])切回 q、k、v。gate、up 同理拼成一次[N, 6144]再切两半。打个比方三个人各跑一趟去同一个仓库取货不如开一辆大车一趟拉回来再分。为什么需要分开做是三次独立的矩阵乘等于三次 kernel 启动、hidden从显存读三遍。decode 每步只算一个 token矩阵很小这时启动一个 kernel 的开销甚至比算它本身还久——三次启动就是三倍的固定开销。解决了什么固定开销与显存读写都降到原来的三分之一一个大矩阵乘也比三个小的更能喂饱 GPU。怎么解决权重在输出维拼接成[204810241024, 1024]前向一次F.linear输出再split切回。三次投影的总输出维不变所以算的乘加次数FLOPs一点没少省下的是 kernel 启动和显存读写——和 L14 把加法融进归一化是同一个道理省的是访存与启动不是算力。# 直观看「合并 分开但只一次 matmul」 torch.manual_seed(0) x torch.randn(3, 1024) # 3 个 token 的 hidden[3, 1024] wq torch.randn(2048, 1024) # q 权重[2048, 1024] wk torch.randn(1024, 1024) # k 权重[1024, 1024] wv torch.randn(1024, 1024) # v 权重[1024, 1024] # 分开三次 F.linear三次 kernel 启动 q, k, v F.linear(x, wq), F.linear(x, wk), F.linear(x, wv) # shape: q [3, 2048] k [3, 1024] v [3, 1024] # 合并权重在输出维 cat 成一个一次 F.linear再 split 切回 w_qkv torch.cat([wq, wk, wv], dim0) # [204810241024, 1024] [4096, 1024] qkv F.linear(x, w_qkv) # [3, 1024] → [3, 4096]一次 q2, k2, v2 qkv.split([2048, 1024, 1024], dim-1) # shape: q2 [3, 2048] k2 [3, 1024] v2 [3, 1024]与分开版逐一对应 print(hidden x :, tuple(x.shape)) # (3, 1024) print(分开 q/k/v :, tuple(q.shape), tuple(k.shape), tuple(v.shape)) print(合并 qkv :, tuple(qkv.shape)) # (3, 4096) print(q 一致 :, torch.allclose(q, q2, atol1e-4)) # True print(k 一致 :, torch.allclose(k, k2, atol1e-4)) # True print(v 一致 :, torch.allclose(v, v2, atol1e-4)) # Truehidden x : (3, 1024) 分开 q/k/v : (3, 2048) (3, 1024) (3, 1024) 合并 qkv : (3, 4096) q 一致 : True k 一致 : True v 一致 : True5. LinearBaseLinearBase是 Linear 家族的基类——持有一张权重表weight外加可选bias并挂一个weight_loader钩子负责把磁盘权重填进来。import torch from torch import nn import torch.nn.functional as F class LinearBase(nn.Module): def __init__(self, input_size, output_size, biasFalse): super().__init__() # 权重矩阵 [输出维, 输入维] self.weight nn.Parameter(torch.empty(output_size, input_size)) self.weight.weight_loader self.weight_loader # 挂加载钩子 if bias: self.bias nn.Parameter(torch.empty(output_size)) self.bias.weight_loader self.weight_loader else: self.register_parameter(bias, None) def weight_loader(self, param, loaded_weight): # 默认整块直接拷 param.data.copy_(loaded_weight) def forward(self, x): # 单卡一句 F.linear return F.linear(x, self.weight, self.bias) # 三个非合并子类单卡下都没有额外动作forward 直接用基类那句 F.linear class ReplicatedLinear(LinearBase): # 单卡整份Qwen3 未用 pass class ColumnParallelLinear(LinearBase): # 单卡整份多卡按输出维切后续介绍 pass class RowParallelLinear(LinearBase): # 单卡整份多卡按输入维切后续介绍 pass6. 实现合并类两个合并类都继承ColumnParallelLinear各做两件事__init__把几路投影的输出维拼成一个大矩阵weight_loader在加载时把磁盘上分开存的几份权重按段填回这块合并参数。是什么weight_loader是挂在参数上的钩子按一个shard_id把某一份磁盘张量拷进合并参数里属于它的那一段行。打个比方合并参数像一个分格的抽屉柜每份权重各自归位到对应格子shard_id是格子编号weight_loader是放进去的动作。为什么需要默认加载就一句param.data.copy_(loaded_weight)要求名字与形状一一对应。合并把几份权重并成一块破坏了这个对应必须有钩子按段填。怎么解决weight_loader先用narrow框出本段在合并参数里该占的行段再把磁盘张量copy_进去。两个合并类的区别只在分几段、偏移怎么算——gate/up 是相等的两段qkv 是按头数算的三段。MergedColumnParallelLineargate/up两段gate、up 两路输出维相等合并成gate_up_proj.weight1024 → 6144。weight_loader按output_sizes累加偏移填段gate 占[0, 3072)shard_id0up 占[3072, 6144)shard_id1。# __init__ 把输出维求和weight_loader 按 output_sizes 累加偏移填段。 class MergedColumnParallelLinear(ColumnParallelLinear): def __init__(self, input_size, output_sizes, biasFalse): self.output_sizes output_sizes # 各段输出维如 [3072, 3072] # 合并参数的输出维 各段之和 super().__init__(input_size, sum(output_sizes), bias) def weight_loader(self, param, loaded_weight, loaded_shard_id): # loaded_shard_id: 0gate, 1up # 本段起始行 前面各段输出维之和gate→0, up→3072 shard_offset sum(self.output_sizes[:loaded_shard_id]) shard_size self.output_sizes[loaded_shard_id] # 3072 # narrow 框出该行段dim0按行输出维再拷进去 param_data param.data.narrow(0, shard_offset, shard_size) param_data.copy_(loaded_weight) # ① __init__合并权重输出维 各段之和 m MergedColumnParallelLinear(1024, [3072, 3072]) # gate, up print(gate_up 合并权重 :, tuple(m.weight.shape)) # (6144, 1024) # ② weight_loader两份磁盘权重用可辨认常数填按 shard_id 落到对应段 gate_w torch.full((3072, 1024), 1.) up_w torch.full((3072, 1024), 2.) m.weight_loader(m.weight, gate_w, 0) # shard_id 0 → [0, 3072) m.weight_loader(m.weight, up_w, 1) # shard_id 1 → [3072, 6144) print(gate 段(前 3072) :, m.weight[:3072].unique().tolist()) # [1.0] print(up 段(后 3072) :, m.weight[3072:].unique().tolist()) # [2.0]gate_up 合并权重 : (6144, 1024) gate 段(前 3072) : [1.0] up 段(后 3072) : [2.0]QKVParallelLinearq/k/v三段q、k、v 头数不同q 16 头、k/v 各 8 头每头 128 维合并成qkv_proj.weight1024 → 4096。weight_loader按头数算偏移填段q 占[0, 2048)、k 占[2048, 3072)、v 占[3072, 4096)shard_id分别为q/k/v。# __init__ 按头数算输出维weight_loader 按 q/k/v 三段填。 class QKVParallelLinear(ColumnParallelLinear): def __init__(self, hidden_size, head_size, total_num_heads, total_num_kv_headsNone, biasFalse): total_num_kv_heads total_num_kv_heads or total_num_heads self.head_size head_size self.num_heads total_num_heads # q 头数 16 self.num_kv_heads total_num_kv_heads # k/v 头数 8 # 输出维 (q头数 k头数 v头数) × head_size output_size (total_num_heads 2 * total_num_kv_heads) * head_size super().__init__(hidden_size, output_size, bias) def weight_loader(self, param, loaded_weight, loaded_shard_id): assert loaded_shard_id in [q, k, v] if loaded_shard_id q: shard_size self.num_heads * self.head_size # 16×128 2048 shard_offset 0 elif loaded_shard_id k: shard_size self.num_kv_heads * self.head_size # 8×128 1024 shard_offset self.num_heads * self.head_size # 偏移 2048 else: # v shard_size self.num_kv_heads * self.head_size # 1024 # 偏移 q 段 k 段 3072 shard_offset (self.num_heads self.num_kv_heads) * self.head_size # narrow 框出合并参数里这一段行再拷进去 param_data param.data.narrow(0, shard_offset, shard_size) param_data.copy_(loaded_weight) # ① __init__Qwen3-0.6B输出维 (16 2×8) × 128 qkv QKVParallelLinear(hidden_size1024, head_size128, total_num_heads16, total_num_kv_heads8) print(qkv 合并权重 :, tuple(qkv.weight.shape)) # (4096, 1024) # ② weight_loaderq/k/v 三份磁盘权重用可辨认常数填按 shard_id 落段 q_w torch.full((2048, 1024), 1.) # q: 16×128 k_w torch.full((1024, 1024), 2.) # k: 8×128 v_w torch.full((1024, 1024), 3.) # v: 8×128 qkv.weight_loader(qkv.weight, q_w, q) # → [0, 2048) qkv.weight_loader(qkv.weight, k_w, k) # → [2048, 3072) qkv.weight_loader(qkv.weight, v_w, v) # → [3072, 4096) print(q 段 :, qkv.weight[:2048].unique().tolist()) # [1.0] print(k 段 :, qkv.weight[2048:3072].unique().tolist()) # [2.0] print(v 段 :, qkv.weight[3072:].unique().tolist()) # [3.0]qkv 合并权重 : (4096, 1024) q 段 : [1.0] k 段 : [2.0] v 段 : [3.0]路由表 packed_modules_mapping合并参数能从分开的磁盘张量拼出来还差一张对照表packed_modules_mapping记录「磁盘上的q_proj→ 合并参数qkv_proj」。加载时遍历权重文件里的每条权重据此把磁盘名换成合并参数名、取出shard_id再用model.get_parameter拿到那块合并参数param调它的weight_loader把张量填进对应段。# Qwen3ForCausalLM.packed_modules_mapping磁盘名 → (合并参数名, shard_id) packed_modules_mapping { q_proj: (qkv_proj, q), k_proj: (qkv_proj, k), v_proj: (qkv_proj, v), gate_proj: (gate_up_proj, 0), up_proj: (gate_up_proj, 1), } # load_model 的路由逻辑摘自 loader.py # for weight_name in 权重文件: # 如 ...q_proj.weight # for k in packed_modules_mapping: # 命中 q_proj # if k in weight_name: # v, shard_id packed_modules_mapping[k] # (qkv_proj, q) # param_name weight_name.replace(k, v) # ...qkv_proj.weight # param model.get_parameter(param_name) # param.weight_loader(param, tensor, shard_id) # 带 shard_id 填段 # break # else: # for 未 break无命中才进 else # param.weight_loader(param, tensor) # 普通参数默认 copy无 shard_id print(q_proj 路由到 :, packed_modules_mapping[q_proj]) # (qkv_proj, q) print(gate_proj 路由 :, packed_modules_mapping[gate_proj]) # (gate_up_proj, 0)q_proj 路由到 : (qkv_proj, q) gate_proj 路由 : (gate_up_proj, 0)7. 集成验证加载真实 Qwen3-0.6B取第 0 层的qkv_proj一个真实QKVParallelLinear验证「合并一次 matmul split」与「分开按段三次投影」逐元素一致——这正是合并省 launch 的前提少两次启动结果不变。import torch import torch.distributed as dist import torch.nn.functional as F from modelscope import snapshot_download from nanovllm.config import Config # 复用 L11 教学版 ModelRunner 加载真实权重 from topic11_model_runner import ModelRunner torch.cuda.set_device(0) if not dist.is_initialized(): dist.init_process_group( nccl, tcp://localhost:2335, world_size1, rank0) model_path snapshot_download(Qwen/Qwen3-0.6B) config Config(model_path, enforce_eagerTrue, max_model_len4096) runner ModelRunner(config) model runner.model # Qwen3ForCausalLM权重 bf16Downloading Model from https://www.modelscope.cn to directory: /DATA/disk5/cache/modelscope/models/Qwen/Qwen3-0.6B 2026-06-07 19:12:53,732 - modelscope - INFO - Target directory already exists, skipping creation.attn0 model.model.layers[0].self_attn qkv attn0.qkv_proj # 真实 QKVParallelLinear qs, kvs attn0.q_size, attn0.kv_size # 2048, 1024 assert qkv.bias is None # Qwen3 用 QK-Norm 替代 qkv bias # 造一份 hiddendtype 跟权重一致 torch.manual_seed(0) hidden torch.randn(4, 1024, devicecuda, dtypeqkv.weight.dtype) with torch.inference_mode(): # 合并一次 matmul再 split merged qkv(hidden) # [4, 4096] q, k, v merged.split([qs, kvs, kvs], dim-1) # 分开把合并权重按行段切出 q/k/v各做一次 F.linear W qkv.weight q2 F.linear(hidden, W[0:qs]) k2 F.linear(hidden, W[qs:qs kvs]) v2 F.linear(hidden, W[qs kvs:qs 2 * kvs]) print(合并输出维 :, tuple(merged.shape)) # (4, 4096) print(q/k/v 偏移 :, 0, qs, qs kvs) # 0 2048 3072 print(q 合并分开 :, torch.allclose(q, q2, atol1e-3)) # True print(k 合并分开 :, torch.allclose(k, k2, atol1e-3)) # True print(v 合并分开 :, torch.allclose(v, v2, atol1e-3)) # True8. 小结Linear 家族一个基类派生五个子类共性是一个weight加一个weight_loader钩子。两个合并类是本篇核心QKVParallelLinear把 q/k/v 拼成一次1024→4096的投影MergedColumnParallelLinear把 gate/up 拼成一次1024→6144。合并不改变算的乘加次数省下的是 kernel 启动与显存读写——三次小投影并成一次大的。代价是加载时要把磁盘上分开的q_proj/k_proj/v_proj按shard_id拼回合并参数这由weight_loaderpacked_modules_mapping完成。下一篇讲解注意力层中qkv切出的 q、k、v 怎么过 QK-Norm、RoPE、attention再经o_proj输出。

相关新闻