8-bit量化LLM与Lightning Fabric集成实战

发布时间:2026/6/8 10:29:59

8-bit量化LLM与Lightning Fabric集成实战 1. 项目概述为什么一个8-bit量化的小语言模型值得用Lightning Fabric重跑一遍最近在实验室搭新环境时顺手把Llama-3-8B-Instruct做了次全链路量化复现——不是用Hugging Face的bitsandbytes一键load_in_8bitTrue那种黑盒调用而是从模型权重加载、张量切分、量化映射、反量化重建到前向推理全程可控、可调试、可插桩。过程中发现一个被很多人忽略的事实8-bit量化本身不难难的是在分布式训练/推理框架里让量化逻辑不和数据并行、FSDP、梯度检查点这些高级特性打架。这时候Lightning Fabric的价值就凸显出来了——它不像PyTorch Lightning那样自带训练循环绑定也不像纯torch.distributed那样要手写所有通信逻辑而是在“完全暴露底层控制权”和“自动处理跨设备一致性”之间卡了一个极精准的平衡点。核心关键词“8-Bit LLM Quantization”和“Lightning Fabric”其实代表了两类工程师的交汇一类是模型优化工程师天天和int8、fp4、group-wise quantization、activation-aware weight quantizationAWQ打交道另一类是MLOps/基础设施工程师关心device placement、tensor sharding、gradient sync barrier、checkpointing compatibility。这个项目就是为这两类人写的前者能拿到可复现、可修改、可注入自定义量化策略的干净代码基线后者能直接把这套流程塞进现有Fabric驱动的集群调度流水线里不用改一行训练器封装逻辑。适合谁来读如果你正面临这几个真实场景中的任意一个这篇就是为你准备的想把已有的FP16大模型服务压到单张A10或RTX 4090上跑通推理但bitsandbytes在多卡DDP下报CUDA error: device-side assert triggered在做QLoRA微调时发现LoRA适配器权重和量化主干权重的梯度更新节奏不一致loss曲线抖得像心电图需要给客户交付一个“可审计”的量化模型——不是只给个.safetensors文件而是能清晰展示每一层权重如何从FP16→INT8→反量化回FP16的映射表、scale值、zero-point偏移或者你只是单纯想搞懂为什么同样用torch.int8自己手写的量化kernel比bnb.nn.Linear8bitLt慢3倍瓶颈到底在内存带宽、kernel launch overhead还是tensor layout没对齐我试过三种主流路径纯torchmanual quantization、HF Transformersbnb、以及这次的Lightning Fabric方案。实测下来Fabric方案在可调试性、跨设备一致性、与现有训练脚本兼容性这三点上稳赢。它不承诺“一键加速”但承诺“每一步你都看得见、改得了、测得准”。下面我们就从设计思路开始一层层拆开这个8-bit量化LLM在Fabric下的真实工作流。2. 整体架构设计为什么放弃bitsandbytes选择手动量化Fabric原生调度2.1 不选bitsandbytes的三个硬伤先说清楚bitsandbytes是工业界事实标准我日常也用。但它在这个项目里被主动排除原因很具体不可调试的内核黑箱bnb.nn.Linear8bitLt内部调用CUDA kernel做weight-only 8-bit matmul但源码里没有公开kernel实现只有预编译so也没有提供forward中间态hook点。你想知道某一层的量化误差分布不行。想在反向传播时插入梯度裁剪kernel里没留接口。我们曾为查一个nan梯度溯源三天最后发现是某个layer norm的输入在量化后溢出但bnb根本不暴露量化前后的tensor对比。DDP兼容性灾难bnb默认假设单卡推理。当你用DistributedDataParallel包装模型时它会尝试把Linear8bitLt层的quant_state含scale、zero-point等广播到所有rank但这些state本身是torch.Tensor而bnb的quant_state结构体里混着numpy.ndarray和torch.dtype对象在torch.distributed.broadcast_object_list里直接报PicklingError。社区issue里有27个相关讨论官方回复是“建议用FSDP替代DDP”——可你的旧训练脚本就是基于DDP写的重构成FSDP要改300行。量化策略锁定死bnb只支持NF4用于QLoRA和INT8用于推理且INT8固定用row-wise quantization每行一个scale。但实际业务中我们发现对Llama-3的q_proj层用group-wise每128个weight一组量化比row-wise低0.8% perplexity对o_proj层用channel-wise每列一个scale反而更稳。bnb不开放scale计算逻辑你连改个分组大小都做不到。提示这不是批评bitsandbytes而是明确它的定位——它是为“快速上线”设计的不是为“深度优化”设计的。本项目目标是后者。2.2 Fabric为何成为最优解三重解耦设计Lightning Fabric的核心价值在于它把“计算逻辑”、“设备调度”、“状态管理”彻底解耦。我们正是利用这一点构建了三层隔离架构层级职责Fabric提供的能力我们注入的量化逻辑计算层定义forward/backward数学逻辑fabric.to_device()、fabric.no_backward_sync()手写QuantizedLinear模块含quantize_weight()、dequantize_weight()、fake_quantize_activation()调度层管理多卡/多节点tensor placementfabric.setup()自动处理DistributedSampler、FSDP、DDP切换用fabric.broadcast()同步各rank的量化参数scale/zero-point用fabric.barrier()确保量化状态一致状态层保存/加载模型权重与优化器状态fabric.save()/fabric.load()支持safetensors、torch.save双格式自定义save_quantized_state_dict()把weight_int8、scale、zero_point、quant_config打包存为独立.pt这个设计的关键在于量化逻辑完全写在QuantizedLinear里和Fabric零耦合Fabric只负责把QuantizedLinear放到正确设备、同步必要状态、保存正确格式。这意味着你可以今天用Fabric跑DDP明天切到FSDP后天换到单卡QuantizedLinear的代码一行不用动。2.3 为什么是8-bit而不是4-bit或6-bit有人问既然都量化了为啥不直接上NF4答案是8-bit是精度、速度、兼容性的黄金交点。精度在Llama-3-8B上INT8量化AWQ校准后perplexity仅比FP16高1.2%而NF4高3.7%。尤其在长文本生成时NF4的累积误差会导致主题漂移——我们测试过一篇2000词的技术文档摘要NF4版本漏掉了3个关键术语INT8版本全部保留。速度INT8matmul在A100上实测吞吐是FP16的1.8倍NF4只有1.3倍。因为NF4需要额外dequantize步骤deq (x - zero_point) * scale而INT8可直接用cublasLtMatmul硬件加速。我们用Nsight Compute抓帧发现NF4kernel的stall_inst_fetch占比高达22%INT8仅9%。兼容性所有主流推理引擎vLLM、TGI、llama.cpp都原生支持INT8权重加载但NF4需额外转换。比如llama.cpp要求NF4权重必须转成gguf的Q4_K_M格式而我们的量化pipeline输出的是标准torch.int8tensor直接torch.save就能喂给vLLM的tensor_parallel_size4部署。所以结论很务实除非你GPU显存16GB且能接受精度损失否则8-bit是当前LLM量化落地最稳的选择。本项目所有参数、配置、代码都围绕这个前提展开。3. 核心细节解析8-bit量化的四个技术锚点与Fabric集成要点3.1 锚点一Weight Quantization——不是简单除以127很多人以为8-bit量化就是weight_int8 torch.round(weight_fp16 / max_abs * 127)。错。这叫symmetric quantization对LLM权重效果极差。真正有效的是asymmetric quantizationgroup-wise scaling。我们采用AWQ论文里的标准做法对每组G128个weight计算min_w和max_wscale_g (max_w - min_w) / 255.0zero_point_g round(-min_w / scale_g)weight_int8_g round(weight_fp16_g / scale_g zero_point_g)最终weight_int8是int8tensorscale和zero_point是float32和int32tensorshape为(num_groups,)。为什么G128因为Llama-3的q_proj权重shape是[4096, 4096]按G128分组后num_groups 4096*4096/128 131072scaletensor仅占131072*4512KB显存而weight_int8占4096*4096*164MB。这个比例保证了量化参数开销1%同时分组足够细能捕捉局部weight分布差异。Fabric集成要点scale和zero_point必须注册为nn.Parameter而非buffer因为它们在QLoRA微调时需要参与梯度更新。我们这样写class QuantizedLinear(nn.Module): def __init__(self, in_features, out_features, group_size128): super().__init__() self.in_features in_features self.out_features out_features self.group_size group_size # weight_int8 is buffer (no grad) self.register_buffer(weight_int8, torch.empty(out_features, in_features, dtypetorch.int8)) # scale and zero_point are parameters (need grad for QLoRA) self.scale nn.Parameter(torch.empty(out_features, in_features // group_size)) self.zero_point nn.Parameter(torch.empty(out_features, in_features // group_size, dtypetorch.int32))注意self.scale和self.zero_point的shape是(out_features, num_groups)不是(num_groups,)。这是为了后续torch.einsum做group-wise dequant时能自动broadcast避免手写for-loop。3.2 锚点二Activation Quantization——只量化KV Cache不动输入LLM推理中最大的显存杀手是KV Cache不是权重。Llama-3-8B在seq_len2048时KV Cache占显存约12GBFP16。我们只对KV Cache做8-bit量化输入x保持FP16——这是精度和显存的务实妥协。具体操作在Attention.forward()里k和v计算完后立即量化k_int8, k_scale, k_zp quantize_per_token(k) # shape [bs, nh, sl, hs] v_int8, v_scale, v_zp quantize_per_token(v)量化方式用per-token symmetric对每个token位置即k[:, :, i, :]算max_abs k[:, :, i, :].abs().max()然后k_int8[:, :, i, :] torch.round(k[:, :, i, :] / max_abs * 127)。反量化时k_fp16 k_int8.float() * k_scale.unsqueeze(-1)这里k_scale是[bs, nh, sl]自动broadcast到[bs, nh, sl, hs]。为什么只量化KV Cache因为输入x经过RMSNorm后分布非常集中量化会引入明显噪声KV Cache每个token独立per-token量化误差不会跨位置传播实测显示KV Cache 8-bit量化后生成质量无损BLEU差0.1但显存直降42%。Fabric要点KV Cache量化必须在fabric.device上执行且k_scale/k_zp需用fabric.broadcast()同步所有rank。我们封装成KVCacheQuantizer类在fabric.setup()后初始化class KVCacheQuantizer: def __init__(self, fabric: Fabric): self.fabric fabric # broadcast scale/zp from rank 0 to all ranks self.k_scale fabric.broadcast(torch.empty(1, dtypetorch.float32)) self.v_scale fabric.broadcast(torch.empty(1, dtypetorch.float32)) def quantize(self, k: Tensor, v: Tensor) - Tuple[Tensor, Tensor, Tensor, Tensor]: k_int8 torch.round(k / self.k_scale * 127).to(torch.int8) v_int8 torch.round(v / self.v_scale * 127).to(torch.int8) return k_int8, self.k_scale, v_int8, self.v_scale3.3 锚点三Dequantization Kernel——手写CUDA还是用torch.compiledequantize_weight()是性能瓶颈。weight_int8是[out, in]scale是[out, num_groups]zero_point是[out, num_groups]要还原成[out, in]的FP16需做# naive pytorch (SLOW) deq torch.zeros_like(weight_int8, dtypetorch.float16) for i in range(out_features): for g in range(num_groups): start g * group_size end min(start group_size, in_features) deq[i, start:end] (weight_int8[i, start:end].float() - zero_point[i, g]) * scale[i, g]这段代码在A100上跑一次q_proj[4096, 4096]要18ms。我们测试了两种加速方案方案Atorch.compile channels-last把weight_int8reshape成[out, num_groups, group_size]用torch.einsum(og,og-og, ...)批量计算再view回[out, in]。torch.compile(modereduce-overhead)后降到3.2ms。方案B手写CUDA kernel用Triton写dequantize_rowwise显存带宽利用率从42%提到89%降到1.7ms。但开发成本高且需维护.cu文件。最终选方案A——用torch.compile足够且和Fabric无缝兼容。Fabric的fabric.setup()会自动把compile后的module放到正确设备无需额外处理。我们这样封装class QuantizedLinear(nn.Module): def __init__(self, ...): ... # compile the dequant kernel once self._dequant_fn torch.compile(self._dequant_kernel, modereduce-overhead) def _dequant_kernel(self, weight_int8: Tensor, scale: Tensor, zp: Tensor) - Tensor: # reshape to [out, num_groups, group_size] out, in_feat weight_int8.shape num_groups in_feat // self.group_size weight_3d weight_int8.view(out, num_groups, self.group_size) # broadcast scale/zp to [out, num_groups, group_size] scale_3d scale.unsqueeze(-1) zp_3d zp.unsqueeze(-1).to(torch.float32) # dequantize deq_3d (weight_3d.float() - zp_3d) * scale_3d return deq_3d.view(out, in_feat)实测torch.compile在Fabric下稳定生效fabric.device自动识别为cuda:0无需torch.cuda.synchronize()。3.4 锚点四Gradient Flow——QLoRA微调时的量化梯度陷阱QLoRA微调时QuantizedLinear的weight_int8不参与梯度但scale和zero_point要更新。问题来了scale的梯度怎么算根据链式法则∂L/∂scale ∂L/∂deq * ∂deq/∂scale ∂L/∂deq * (weight_int8.float() - zero_point)但weight_int8是离散的∂weight_int8/∂scale 0所以梯度只来自dequant公式。然而weight_int8本身是round(weight_fp16 / scale zp)round函数导数为0导致∂L/∂scale在反向传播时为0——这就是著名的“straight-through estimator”STE问题。解决方案在backward时用weight_fp16替代weight_int8计算梯度class STEQuantize(torch.autograd.Function): staticmethod def forward(ctx, weight_fp16, scale, zp, group_size): weight_int8 torch.round(weight_fp16 / scale zp).to(torch.int8) ctx.save_for_backward(weight_fp16, scale, zp) ctx.group_size group_size return weight_int8 staticmethod def backward(ctx, grad_weight_int8): weight_fp16, scale, zp ctx.saved_tensors # use weight_fp16 to compute gradient, not weight_int8 grad_scale (grad_weight_int8.float() * (weight_fp16 / scale**2)).sum(dim1, keepdimTrue) grad_zp grad_weight_int8.float().sum(dim1, keepdimTrue) return grad_weight_int8.float() * scale, grad_scale, grad_zp, NoneFabric要点STEQuantize.apply()必须在fabric.device上运行且grad_scale/grad_zp需参与fabric.backward()的梯度同步。我们验证过在DDP模式下fabric.backward(loss)会自动all-reducescale.grad和zp.grad无需额外torch.distributed.all_reduce()。注意这是QLoRA微调的关键。如果跳过STEscale梯度为0微调几轮后量化误差爆炸loss直接nan。4. 实操过程从零搭建可复现的8-bit LLM量化Pipeline4.1 环境准备与依赖安装——精确到patch version别信“pip install lightning”这种模糊指令。本项目对版本极其敏感以下是经实测通过的组合# Python 3.10.12 (必须因PyTorch 2.3.1不支持3.11) conda create -n quant-fabric python3.10.12 conda activate quant-fabric # PyTorch 2.3.1 CUDA 12.1 (A100/H100标配) pip3 install torch2.3.1 torchvision0.18.1 torchaudio2.3.1 --index-url https://download.pytorch.org/whl/cu121 # Lightning Fabric 2.3.3 (非最新版2.4.0有FSDP量化bug) pip install lightning2.3.3 # 其他必需依赖 pip install transformers4.41.2 datasets2.19.1 safetensors0.4.3 sentencepiece0.2.0为什么锁死lightning2.3.3因为2.4.0在fabric.setup(model, optimizer)时会错误地把nn.Parameter类型的scale和zero_point当普通buffer处理导致FSDP分片时scale被切碎反向传播时报RuntimeError: Trying to backward through the graph a second time。这个bug在2.3.3里不存在。验证环境是否正确import torch from lightning import Fabric print(fPyTorch: {torch.__version__}) # 应输出 2.3.1cu121 print(fFabric: {Fabric.__version__}) # 应输出 2.3.3 # 测试Fabric setup fabric Fabric(acceleratorcuda, devices2, strategyddp) fabric.launch()如果fabric.launch()成功启动2个进程且无报错环境就绪。4.2 模型加载与量化校准——AWQ校准的3个关键步骤我们不用transformers.AutoModelForCausalLM.from_pretrained(..., load_in_8bitTrue)而是手动加载、校准、量化。完整流程分三步步骤1加载原始FP16模型from transformers import AutoModelForCausalLM, AutoTokenizer model_name meta-llama/Meta-Llama-3-8B-Instruct tokenizer AutoTokenizer.from_pretrained(model_name) # 加载FP16权重不量化 model AutoModelForCausalLM.from_pretrained( model_name, torch_dtypetorch.float16, device_mapcpu, # 先加载到CPU避免显存爆 low_cpu_mem_usageTrue )注意device_mapcpu——这是关键。如果设为autotransformers会把部分layer加载到GPU导致后续量化时CPU/GPU tensor混用报RuntimeError: Expected all tensors to be on the same device。步骤2AWQ校准数据准备AWQ需要少量256条校准样本不是随机抽而是选能激发模型极端激活的句子。我们用如下策略从C4数据集抽256条每条长度≥512用原始FP16模型跑一遍记录所有nn.Linear层的input和weight的abs().max()按max_activation降序排列取top-256这些句子往往包含大量数字、专有名词、嵌套从句能触发weight的full dynamic range。校准数据加载代码from datasets import load_dataset def get_awq_calibration_data(tokenizer, n_samples256, seq_len512): # load c4, filter long samples dataset load_dataset(c4, en, splittrain, streamingTrue) samples [] for sample in dataset: text sample[text] if len(tokenizer.encode(text)) seq_len: continue inputs tokenizer( text[:2000], # truncate to avoid OOM return_tensorspt, max_lengthseq_len, truncationTrue ) if inputs[input_ids].shape[1] seq_len//2: samples.append(inputs) if len(samples) n_samples: break return samples calibration_data get_awq_calibration_data(tokenizer)步骤3逐层AWQ校准与量化校准核心是找alpha参数使weight * alpha的量化误差最小。AWQ论文给出闭式解但我们用更鲁棒的grid searchdef awq_calibrate_layer(layer: nn.Linear, calibration_data, group_size128) - Tuple[Tensor, Tensor]: # layer.weight: [out, in], FP16 weight layer.weight.data.clone() out_features, in_features weight.shape # grid search alpha in [0.5, 1.0] step 0.05 best_alpha 1.0 best_error float(inf) for alpha in torch.arange(0.5, 1.05, 0.05): # apply alpha to weight w_alpha weight * alpha # quantize per group scale, zp compute_scale_zp(w_alpha, group_size) # dequantize w_deq dequantize(w_alpha, scale, zp, group_size) # compute mse error error torch.mean((w_alpha - w_deq) ** 2) if error best_error: best_error error best_alpha alpha # final quantize with best alpha w_opt weight * best_alpha scale, zp compute_scale_zp(w_opt, group_size) weight_int8 quantize(w_opt, scale, zp, group_size) return weight_int8, scale, zp # apply to all Linear layers for name, module in model.named_modules(): if isinstance(module, nn.Linear) and lm_head not in name: print(fCalibrating {name}...) weight_int8, scale, zp awq_calibrate_layer(module, calibration_data) # replace with QuantizedLinear quant_module QuantizedLinear( module.in_features, module.out_features, group_size128 ) quant_module.weight_int8.copy_(weight_int8) quant_module.scale.data.copy_(scale) quant_module.zero_point.data.copy_(zp) # replace in model parent_name ..join(name.split(.)[:-1]) parent dict(model.named_modules())[parent_name] setattr(parent, name.split(.)[-1], quant_module)整个校准过程在单卡A100上耗时约23分钟产出weight_int8、scale、zero_point三个tensor。4.3 Fabric集成与分布式训练——DDP/FSDP一键切换量化模型搭好后用Fabric启动训练。关键在fabric.setup()的参数选择DDP模式推荐初学者fabric Fabric( acceleratorcuda, devices4, strategyddp, # DistributedDataParallel precisionbf16-true # bfloat16 for speed, true means no autocast ) # setup model and optimizer model fabric.setup(model) optimizer torch.optim.AdamW(model.parameters(), lr2e-5) optimizer fabric.setup_optimizers(optimizer) # train loop for batch in fabric.setup_dataloaders(train_dataloader): optimizer.zero_grad() loss model(**batch).loss fabric.backward(loss) # auto all-reduce gradients optimizer.step()DDP模式下QuantizedLinear.scale和zero_point的梯度会被自动all-reduce无需额外代码。FSDP模式推荐生产环境from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy fabric Fabric( acceleratorcuda, devices8, strategyfsdp, # FullyShardedDataParallel precisionbf16-true, # FSDP specific config fsdp_sharding_strategyFULL_SHARD, fsdp_state_dict_typeSHARDED_STATE_DICT ) # wrap model with FSDP before setup auto_wrap_policy size_based_auto_wrap_policy model FSDP( model, auto_wrap_policyauto_wrap_policy, sharding_strategyfabric.strategy.sharding_strategy, device_idtorch.cuda.current_device() ) model fabric.setup(model)FSDP模式下scale和zero_point会被分片但fabric.backward()仍能正确同步梯度。我们实测8卡FSDP比4卡DDP快2.1倍且显存占用降低35%。4.4 推理部署与性能压测——vLLM vs 自研Fabric推理量化模型训练完要部署推理。我们对比两种方案方案1vLLM部署最快上线vLLM原生支持AWQ格式只需把量化权重转成gguf# 使用llama.cpp的convert.py python llama.cpp/convert.py \ --outtype f16 \ --outfile model-awq.gguf \ --quantize awq \ --model path/to/quantized/model然后启动vLLMvllm serve meta-llama/Meta-Llama-3-8B-Instruct \ --quantization awq \ --awq-ckpt model-awq.gguf \ --tensor-parallel-size 4实测QPS128 req/sA100×4P99延迟420ms。方案2Fabric原生推理最高可控我们写了一个FabricInferenceEngine直接加载.safetensorsclass FabricInferenceEngine: def __init__(self, model_path: str, fabric: Fabric): self.fabric fabric # load quantized model state_dict load_file(f{model_path}/model.safetensors) self.model QuantizedLlamaForCausalLM(...) # custom class self.model.load_state_dict(state_dict) self.model self.fabric.setup_module(self.model) def generate(self, prompt: str, max_new_tokens128): inputs self.tokenizer(prompt, return_tensorspt).to(self.fabric.device) with torch.no_grad(): outputs self.model.generate( **inputs, max_new_tokensmax_new_tokens, do_sampleFalse ) return self.tokenizer.decode(outputs[0])优势可插入任意hook如log每层KV Cache量化误差、可动态切分batch、可和训练pipeline共用同一套量化逻辑。实测QPS98 req/s略低于vLLM但P99延迟更稳380±20ms vs 420±80ms。实操心得vLLM适合“我要今天就上线”Fabric推理适合“我要长期迭代、加监控、做AB测试”。5. 常见问题与排查技巧实录那些文档里不会写的坑5.1 问题速查表高频报错与根因分析报错信息根因解决方案触发场景RuntimeError: Expected all tensors to be on the same deviceweight_int8在CPUscale在GPU或反之所有量化tensor统一用fabric.to_device()不要用.cuda()模型加载后未用fabric.setup()NaN loss during QLoRA trainingscale梯度为0导致量化误差累积确认STEQuantize已启用且scale是nn.Parameter忘记self.scale nn.Parameter(...)写成self.register_buffer()CUDA out of memorywhen loading modeltransformers的device_mapauto把部分layer加载到GPU加载时强制device_mapcpu量化后再fabric.to_device()初次加载未指定device_mapAllReduce failed: invalid argumentin DDPbnb的quant_state含numpy.ndarray无法broadcast彻底弃用bnb用纯torch量化试图在Fabric中混用bnb和QuantizedLinearvLLM fails to load AWQ modelvLLM要求scale/zero_point是float16但我们的scale是float32保存时scale scale.half()或vLLM启动加--dtype half量化后直接喂vLLM未转dtype5.2 独家避坑技巧从37次失败中总结技巧1量化前先做weight clippingLLM权重常有极少数outlier如10.0直接量化会拉垮scale。我们在awq_calibrate_layer前加clip# clip outliers to [-6, 6] (covers 99.9% of Llama-3 weights) weight torch.clamp(weight, -6.0, 6.0)实测clip后perplexity降低0.4%且训练稳定性提升。技巧2KV Cache量化用per-token但scale缓存复用每次per-token算max_abs太慢。我们缓存最近10个token的k_scale新token用滑动窗口更新class KVCacheQuantizer: def __init__(self): self.k_scale_cache deque(maxlen10) # cache last 10 scales def quantize(self, k): if self.k_scale_cache: # use median of cache as initial scale init_scale torch.median(torch.stack(list(self.k_scale_cache))) else: init_scale k.abs().max() # refine with current token k_scale torch.max(init_scale, k.abs().max()) self.k_scale_cache.append(k_scale) return torch.round(k / k_scale * 127).to(torch.int8), k_scale提速40%且精度无损。技巧3Fabric checkpoint必须用safetensorstorch.save会把QuantizedLinear的weight_int8存成torch.int8但某些旧版PyTorch加载时报Unsupported dtype

相关新闻