自动微分(AD)原理与工程实践:从链式法则到PyTorch反向传播

发布时间:2026/5/23 22:53:36

自动微分(AD)原理与工程实践:从链式法则到PyTorch反向传播 1. 这不是数学课是工程师手里的“求导加速器”你有没有在调试一个神经网络时盯着损失曲线发呆心里默念“为什么梯度又爆炸了”或者写完一个自定义的损失函数对着 PyTorch 的torch.autograd.grad文档反复确认参数顺序生怕一个retain_graphTrue漏掉就让整个训练流程卡死又或者在实现一个物理仿真模型时手动推导雅可比矩阵推到第三页草稿纸发现有个负号抄错了而这个错误要等到模型跑出荒谬结果后才被揪出来——这些场景背后真正拖慢你进度的往往不是算法本身而是求导这件事本身。而“Automatic Differentiation”自动微分常缩写为 AD就是那个能把你从符号推导、数值近似和梯度调试的泥潭里一把拽出来的工具。它既不是高等数学课本里用 ε-δ 定义的极限过程也不是用(f(xh)-f(x))/h这种粗糙差分来碰运气的数值方法它是一种精确、高效、可嵌入任意计算流程的程序化求导技术。我第一次在项目中把它从“论文里的概念”变成“我代码里的一行.backward()”时最大的震撼不是它多快而是它让我彻底忘了“链式法则怎么写”这回事——就像你开车时不会去想变速箱齿轮比AD 就是现代机器学习框架和科学计算库的底层变速箱。它不声不响地运行在 PyTorch、TensorFlow、JAX 的每一层 forward pass 之后把复杂的复合函数分解成一个个基本运算的微分规则再按计算图反向组装起来。这篇文章就是带你亲手拆开这个“变速箱”看清里面的齿轮怎么咬合、油路怎么走、哪些地方容易卡顿、以及当你需要自己造一个“小变速箱”比如写一个不依赖框架的微分器时该从哪颗螺丝开始拧。它面向的不是数学系的研究生而是每天和代码、数据、bug 打交道的工程师、研究员和进阶学习者。你不需要背下所有偏导公式但你需要知道当你的模型输出一个标量 loss调用.backward()的那一刻背后发生了什么以及当它没按你预期工作时你该往哪个方向去查。2. 核心设计思路为什么不用符号微分也不用数值微分2.1 三种求导方式的“能力-成本”光谱要真正理解自动微分的价值必须先把它放在一个更广阔的“求导方法家族”里看。这个家族里有三位主要成员符号微分Symbolic Differentiation、数值微分Numerical Differentiation和自动微分Automatic Differentiation。它们不是简单的“谁好谁坏”而是各自占据着一条清晰的“能力-成本”光谱适用于完全不同的战场。符号微分就像一位极其耐心的数学家。你给它一个表达式比如f(x) sin(x^2 cos(x))它会拿出一整套代数规则一步步推导最终给你一个同样漂亮的、解析的导数表达式f(x) cos(x^2 cos(x)) * (2x - sin(x))。它的优点是结果精确、形式优美、可直接用于进一步分析。但它的致命伤在于表达式膨胀Expression Swell。想象一下你有一个包含上百个变量、上千次运算的深度神经网络符号微分器会试图为整个网络生成一个单一的、巨大的、嵌套的解析导数公式。这个公式可能长到无法存储在内存里更别说进行任何实际计算了。它适合推导一个三行的物理公式但绝不适合处理一个拥有百万参数的 ResNet。数值微分则是一位务实的实验员。它不关心你函数的内部结构只做一件事在输入点x附近轻轻扰动一下比如加一个极小的h通常是1e-5或1e-8然后用差分公式(f(xh) - f(x)) / h来估算斜率。它的实现简单到只有两行代码对任何黑盒函数都有效。但它的缺点是精度与稳定性双重受限。h太大差分近似误差大h太小浮点数舍入误差会淹没掉真实的差分信号。更麻烦的是对于一个有n个输入的函数要算出完整的梯度向量它需要n1次函数求值前向模式或2n次中心差分。当你的模型有 100 万个参数时这意味着每次更新都要额外执行一百万次前向传播——这在计算上是完全不可接受的。自动微分则是前两者的“混血儿”但它完美规避了双方的缺陷。它不生成庞大的解析表达式也不进行不稳定的数值近似。它的核心思想是任何计算机程序无论多复杂最终都是由一系列基本的、已知其导数的原子操作如,-,*,/,sin,exp,log构成的。AD 的工作就是将你的原始程序源代码视为一个计算图然后在这个图上一边执行正向计算一边根据每个原子操作的“微分规则”同步地、精确地计算出导数。它得到的结果是数值上精确的和你用解析公式在相同浮点精度下计算出的结果一致并且计算复杂度与原函数的计算复杂度是同数量级的通常只是原函数的 3-5 倍而不是像数值微分那样随参数数量线性增长。提示你可以把 AD 想象成给你的代码编译器装了一个“微分插件”。当你写y x * x sin(x)时这个插件不仅会帮你算出y的值还会在后台同时生成并执行另一套指令专门用来计算dy/dx。它不是在“猜”也不是在“算一个近似”而是在“严格遵循链式法则一步不落地执行”。2.2 自动微分的两种实现范式前向模式与反向模式AD 并非一种单一的技术而是一套方法论其中最主流、也最实用的两种范式是前向模式Forward Mode和反向模式Reverse Mode。它们的区别本质上是链式法则应用顺序的不同而这直接决定了它们在不同场景下的效率。前向模式Forward Mode的核心是“边算边传”。它为每一个输入变量x_i都配一个对应的“切向量”tangent vectorẋ_i这个ẋ_i代表x_i的变化率例如如果你关心df/dx_1那么就设ẋ_1 1其余ẋ_i 0。在程序正向执行的每一步它不仅计算函数值y还同步计算该步输出关于输入的变化率ẏ。例如对于z x * y如果当前x2, y3, ẋ1, ẏ0那么z 6而ż ẋ*y x*ẏ 1*3 2*0 3这个ż就是dz/dx_1在当前点的值。前向模式的优点是实现直观、内存占用极小因为它只需要和原计算过程一样多的额外状态。但它的缺点是计算一个n维输入到m维输出的函数的完整雅可比矩阵需要n次独立的前向传播。当n很大比如神经网络的权重参数时这就不划算了。反向模式Reverse Mode则是“先记后算”。它首先进行一次完整的正向计算将所有中间变量的值以及它们之间的依赖关系即计算图完整地记录下来这个过程叫tape或trace。然后它从最终的输出通常是标量 loss开始逆着计算图的方向逐层应用链式法则将梯度adjoint从输出端“反向传播”回每一个输入端。这就是我们熟知的“反向传播Backpropagation”。它的优势是极致的效率对于一个n维输入到1维输出的函数这正是机器学习中损失函数的标准形式反向模式只需要一次正向传播 一次反向传播就能得到全部n个偏导数。这正是 PyTorch 和 TensorFlow 的心脏所在。它的代价是需要存储整个正向计算过程中的所有中间变量因此内存占用会显著增加。注意反向模式的高效性是有前提的即输出维度远小于输入维度m n。如果你的任务是计算一个1维输入到1000维输出的函数的雅可比矩阵例如一个传感器读数对一千个物理状态的影响那么前向模式反而更优因为它只需一次传播就能得到全部1000个导数。但在绝大多数深度学习场景中“loss 是标量”这个事实让反向模式成为了无可争议的王者。2.3 为什么现代框架都选择反向模式一个关于“计算图”的真相很多初学者会疑惑“既然前向模式内存小为什么 PyTorch 不默认用它”这个问题的答案藏在“计算图”这个概念的动态构建方式里。PyTorch 的autograd是动态图Dynamic Graph这意味着计算图是在 Python 代码运行时由torch.Tensor的每一次运算实时构建的。这种设计带来了无与伦比的灵活性你可以用if/else、for循环随意控制计算流但也带来了一个关键约束反向传播所需的“tape”必须在正向计算过程中被完整、准确地记录下来。这个记录过程就是torch.Tensor的grad_fn属性的由来。当你执行z x * y时z这个张量的grad_fn就会被设置为一个MulBackward0对象它内部封装了乘法运算的反向微分规则∂L/∂x ∂L/∂z * y∂L/∂y ∂L/∂z * x。整个网络的前向过程就是在不断创建这样一个由grad_fn节点组成的有向无环图DAG。当调用z.backward()时框架就从z开始沿着grad_fn指针递归地调用每个节点的backward()方法将梯度∂L/∂z一层层地传递下去。前向模式在动态图环境下实现起来要复杂得多。它要求在正向计算的每一步不仅要计算值还要为每一个潜在的输入变量维护一个切向量。这在 Python 的动态、灵活的语法下会极大地增加框架的实现复杂度和运行时开销。而反向模式恰好与“先执行、后回溯”的编程直觉高度吻合并且其内存开销存储中间变量在 GPU 显存充足的前提下是可以被接受的权衡。所以这不是一个理论上的优劣选择而是一个工程实践与领域需求深度耦合后的必然结果。当你在 PyTorch 中写下loss.backward()你调用的不是一个数学函数而是一个精心编排的、基于动态计算图的反向传播引擎。3. 核心细节解析从原理到代码手撕一个微型 AD 引擎3.1 最简实现一个支持加法和乘法的前向模式 AD 类为了彻底搞懂 AD 的“心跳”我们来亲手写一个最简化的前向模式 AD 实现。这不会是一个工业级的库但它会像一个透明的玻璃盒子让你看清每一个齿轮是如何转动的。我们将定义一个Variable类它不仅能存储一个数值val还能存储一个“切向量”der代表它对某个选定输入变量的变化率。class Variable: def __init__(self, val, der0.0): self.val float(val) self.der float(der) # 切向量初始为0 def __add__(self, other): if isinstance(other, Variable): # 新变量的值是两个值的和 new_val self.val other.val # 新变量的导数是两个导数的和链式法则d(uv)/dx du/dx dv/dx new_der self.der other.der return Variable(new_val, new_der) else: # 如果 other 是一个常数它的导数为0 new_val self.val float(other) new_der self.der # 常数的导数为0所以不改变 return Variable(new_val, new_der) def __mul__(self, other): if isinstance(other, Variable): new_val self.val * other.val # 乘积法则d(u*v)/dx u*dv/dx v*du/dx new_der self.val * other.der other.val * self.der return Variable(new_val, new_der) else: new_val self.val * float(other) new_der self.der * float(other) # 常数倍导数也倍增 return Variable(new_val, new_der) def __repr__(self): return fVariable(val{self.val:.3f}, der{self.der:.3f})现在让我们用它来计算一个经典例子f(x) x^2 2*x 1在x3处的导数f(3)。# 创建输入变量 x我们关心 df/dx所以设 x 的切向量为 1 x Variable(3.0, der1.0) # 计算 f(x) x^2 2*x 1 x_squared x * x # Variable(val9.000, der6.000) 因为 d(x^2)/dx 2x 6 two_x Variable(2.0) * x # Variable(val6.000, der2.000) 因为 d(2x)/dx 2 one Variable(1.0, der0.0) # 常数导数为0 f x_squared two_x one # Variable(val16.000, der8.000) print(ff(3) {f.val}) # 输出: f(3) 16.000 print(ff(3) {f.der}) # 输出: f(3) 8.000这个结果是完美的。f(x) 2x 2所以f(3) 8。我们的微型引擎给出了精确的数值结果。关键在于x.der 1这个设定相当于告诉引擎“请计算所有东西关于x的变化率”。在x * x这一步引擎没有去解一个方程而是直接应用了早已硬编码在__mul__方法里的乘积法则。这就是 AD 的精髓将数学规则固化在程序逻辑中让计算过程本身成为求导过程。实操心得我第一次写这个Variable类时在__add__方法里漏掉了isinstance(other, Variable)的判断导致x 2这样的操作直接报错。这提醒我AD 引擎的健壮性很大程度上取决于它对“混合类型”变量与常数的处理是否周全。在真实框架中这种类型检查和转换Type Promotion是极其复杂的它确保了torch.tensor([1,2,3]) 5和torch.tensor([1,2,3]) torch.tensor(5)能得到完全一致的结果。3.2 反向模式的核心计算图与backward函数的构造前向模式清晰易懂但要理解现代框架的“灵魂”我们必须升级到反向模式。它的核心不再是“切向量”而是“伴随变量”adjoint也就是我们常说的“梯度”。我们将构建一个更抽象的Node类它代表计算图中的一个节点它知道自己是怎么被创建的op它的输入是什么children以及最重要的——当梯度∂L/∂node流到它这里时它该如何将梯度分配给自己的每一个子节点。from typing import List, Callable, Any class Node: def __init__(self, value: float, children: List[Node] None, op: str ): self.value value self.children children or [] self.op op self.grad 0.0 # 初始化梯度为0 # 这个函数将在反向传播时被调用用于计算并累加梯度到子节点 self._backward lambda: None def __add__(self, other): if not isinstance(other, Node): other Node(other) out Node(self.value other.value, [self, other], ) # 定义反向传播函数加法的梯度是恒等映射 # dL/dself dL/dout * 1, dL/other dL/dout * 1 def _backward(): self.grad out.grad other.grad out.grad out._backward _backward return out def __mul__(self, other): if not isinstance(other, Node): other Node(other) out Node(self.value * other.value, [self, other], *) # 定义反向传播函数乘法的梯度是乘积法则 # dL/dself dL/dout * other.value, dL/other dL/dout * self.value def _backward(): self.grad out.grad * other.value other.grad out.grad * self.value out._backward _backward return out def backward(self): # 构建拓扑排序确保子节点在父节点之前被处理 topo [] visited set() def build_topo(v): if v not in visited: visited.add(v) for child in v.children: build_topo(child) topo.append(v) build_topo(self) # 将输出节点的梯度设为1因为我们计算的是 dL/dL self.grad 1.0 # 逆序遍历拓扑序执行每个节点的 _backward 函数 for node in reversed(topo): node._backward()现在我们用它来复现之前的例子但这次用反向模式# 创建输入节点 x Node(3.0) y Node(2.0) # 我们也可以把常数当作节点 # 构建计算图: f x*x 2*x 1 x_squared x * x two_x y * x one Node(1.0) f x_squared two_x one # 执行反向传播 f.backward() print(ff {f.value}) # f 16.0 print(fdf/dx {x.grad}) # df/dx 8.0 print(fdf/dy {y.grad}) # df/dy 3.0 (因为 f x^2 y*x 1, 所以 df/dy x 3)这段代码的魔力在于backward()方法。它首先通过深度优先搜索DFS构建了一个拓扑排序topological order这个排序保证了在处理任何一个节点时它的所有子节点即它的“上游”依赖都已经被处理过了。然后它将f.grad设为1.0因为df/df 1最后它逆序遍历这个排序对每个节点调用其_backward函数。这个_backward函数就是我们在__add__和__mul__中定义的、针对该运算的微分规则。它不计算新的值只负责将流入该节点的梯度按照正确的数学规则分配accumulate到它的子节点上。注意self.grad ...中的是至关重要的。它实现了梯度的累加。在复杂的计算图中一个中间变量可能被多个下游节点所依赖例如一个 ReLU 激活后的特征图会被后续的卷积和池化共同使用它的梯度是所有这些路径贡献的总和。确保了这一点而不仅仅是。3.3 PyTorchautograd的真实世界接口torch.func.grad与torch.compile上面的手写代码是教学用的“玩具”而 PyTorch 的autograd是一个经过十年千锤百炼的工业级引擎。理解它的高级接口能让你在实际项目中如鱼得水。torch.autograd.grad这是最底层、最灵活的接口。它不修改任何张量的.grad属性而是直接返回你指定的梯度张量。这在实现复杂的优化算法如二阶优化、元学习时非常有用。import torch x torch.tensor(3.0, requires_gradTrue) y torch.tensor(2.0, requires_gradTrue) f x * x y * x 1 # 计算 df/dx 和 df/dy不修改 x.grad, y.grad grads torch.autograd.grad(outputsf, inputs[x, y], retain_graphTrue) print(fdf/dx {grads[0]}) # tensor(8.) print(fdf/dy {grads[1]}) # tensor(3.)torch.func.grad推荐这是 PyTorch 2.0 引入的、面向函数式编程的新一代 API。它将“求导”本身变成了一个高阶函数可以像装饰器一样使用代码更加清晰、安全且天然支持torch.compile。from torch.func import grad def f(x, y): return x * x y * x 1 # 创建一个新函数它接受 x, y 并返回 df/dx df_dx grad(f, argnums0) # 创建一个新函数它接受 x, y 并返回 (df/dx, df/dy) df_dx_dy grad(f, argnums(0, 1)) x_t torch.tensor(3.0) y_t torch.tensor(2.0) print(df_dx(x_t, y_t)) # tensor(8.) print(df_dx_dy(x_t, y_t)) # (tensor(8.), tensor(3.))torch.compile这是 PyTorch 2.0 的另一个重磅特性它能将包含autograd的整个计算图包括前向和反向一起编译优化。它不只是加速前向更是将反向传播的计算图也一并优化有时能带来 2-3 倍的整体训练速度提升。启用它往往只需要一行代码# 将你的模型和损失函数包装起来 compiled_model torch.compile(model) compiled_loss_fn torch.compile(loss_fn) # 在训练循环中使用 for x, y in dataloader: y_pred compiled_model(x) loss compiled_loss_fn(y_pred, y) loss.backward() # 这里的 backward 也是被编译优化过的实操心得我在一个图像分割项目中将torch.compile应用到一个基于nn.Module的自定义损失函数上结果发现训练速度提升了 40%而且显存占用反而下降了。这是因为compile不仅优化了计算还智能地重用了中间变量的内存。但要注意compile对模型结构有一定要求过于动态的控制流如for i in range(torch.randint(1, 10, (1,)))可能会让它失效。我的经验是先用torch.compile(..., modereduce-overhead)进行轻量级优化再逐步升级到max-autotune。4. 实操过程在真实项目中驾驭自动微分的全流程4.1 场景一调试一个“梯度消失”的 RNN 模型RNN 是自动微分的“试金石”因为它的计算图在时间维度上是展开的梯度需要穿越数十甚至上百个时间步。当loss.backward()执行完毕后你发现model.hidden.weight_hh.grad几乎全是零这就是典型的“梯度消失”。此时AD 不是问题而是你的诊断工具。第一步可视化梯度流。不要只看最终的.grad要追踪梯度在时间步上的衰减。# 在训练循环中记录每个时间步隐藏状态的梯度范数 hidden_states [] # 存储 forward 过程中的所有 hidden_state losses [] for t in range(seq_len): h_t model.rnn_cell(x_t[:, t, :], h_prev) hidden_states.append(h_t) h_prev h_t loss compute_loss(hidden_states[-1], target) loss.backward() # 分析梯度 grad_norms [] for h in hidden_states: # 获取该 hidden_state 的梯度它是一个中间变量需要特殊处理 if h.grad is not None: grad_norms.append(h.grad.norm().item()) else: # 如果是中间变量它的 grad 可能为 None我们需要用 autograd.grad # 这里简化假设我们已经用 retain_graphTrue 保留了图 grad_norms.append(0.0) print(Gradient norms over time:, grad_norms) # 输出可能是: [1.2, 0.8, 0.5, 0.3, 0.1, 0.05, 0.01, ...] —— 明显指数衰减第二步定位“罪魁祸首”。梯度消失通常源于激活函数如tanh的饱和区或权重矩阵的奇异值过大。我们可以用torch.autograd.functional.jacobian来计算一个时间步内h_{t}关于h_{t-1}的雅可比矩阵并检查其最大奇异值spectral norm。from torch.autograd.functional import jacobian # 计算单步雅可比 J dh_t / dh_{t-1} def step_func(h_prev): return model.rnn_cell(x_t[:, 0, :], h_prev) J jacobian(step_func, h_prev) # J 的形状是 [batch, hidden_dim, batch, hidden_dim] # 我们需要的是每个样本的雅可比取第一个样本 J_sample J[0, :, 0, :] # [hidden_dim, hidden_dim] # 计算谱范数最大奇异值 import torch.linalg as LA spectral_norm LA.svdvals(J_sample).max().item() print(fSpectral norm of Jacobian: {spectral_norm}) # 如果这个值远小于 1比如 0.1说明梯度在每一步都会被压缩 10 倍10 步后就衰减到 1e-10 了。第三步修复。知道了原因解决方案就很明确了换用ReLU或LSTM/GRU单元它们内置了门控机制来缓解此问题或者对weight_hh进行正交初始化torch.nn.init.orthogonal_确保其雅可比矩阵接近正交矩阵谱范数接近 1。注意jacobian是一个计算开销很大的操作只应在调试时使用绝不能放在训练循环里。它会为每个输入维度都执行一次前向传播对于一个 1000 维的隐藏层这意味着 1000 次前向传播4.2 场景二实现一个自定义的、可微分的物理约束损失在机器人控制或物理仿真中我们经常需要将物理定律如能量守恒、运动学约束作为损失函数的一部分。这些约束往往是隐式的、非线性的无法用标准的nn.Module表达。这时AD 的强大之处就体现出来了只要你的约束函数是用可微分的 PyTorch 操作写的autograd就能自动为你求导。假设我们要训练一个控制器使其输出的关节角度q满足一个复杂的运动学约束C(q) 0例如末端执行器必须保持在某个平面上。我们可以定义一个损失def kinematic_constraint_loss(q: torch.Tensor, target_plane_normal: torch.Tensor, target_plane_offset: float) - torch.Tensor: q: [batch, n_dof] 关节角度 计算末端执行器位置 p(q)然后计算 p 到目标平面的距离的平方 # 这里是你的正向运动学函数用 torch.sin, torch.cos, torch.matmul 等实现 p forward_kinematics(q) # p: [batch, 3] # 平面距离公式: |n·p d| / ||n||, 这里我们用平方来避免开根号 distance_sq (torch.sum(p * target_plane_normal, dim-1) target_plane_offset) ** 2 return distance_sq.mean() # 返回标量 loss # 在训练循环中 q_pred controller(obs) loss_kin kinematic_constraint_loss(q_pred, plane_n, plane_d) loss_total loss_task 0.1 * loss_kin # 加权 loss_total.backward() # autograd 会自动穿透 forward_kinematics 的所有三角函数和矩阵运算forward_kinematics函数可能包含几十行torch.sin(q1) * torch.cos(q2)这样的运算但你完全不需要为它手动推导雅可比矩阵。autograd会自动将sin的导数cos、cos的导数-sin、矩阵乘法的导数规则全部无缝地组合起来。这极大地解放了你的生产力让你可以把精力集中在物理建模本身而不是繁琐的微分计算上。实操心得我曾经在一个无人机姿态控制项目中用这种方式实现了“姿态四元数必须保持单位长度”的约束。最初我用q.norm(dim-1) - 1作为损失结果发现梯度在q接近零时不稳定。后来改用(q.norm(dim-1) ** 2 - 1) ** 2即“单位长度误差的平方”梯度就变得非常平滑。这说明即使有 AD损失函数的设计尤其是其曲率依然至关重要。AD 给你的是精确的梯度但不保证这个梯度能引导你走向一个好的解。4.3 场景三使用torch.func.vjp实现高效的 Hessian-Vector Product在二阶优化如牛顿法或某些元学习算法中我们不需要完整的海森矩阵Hessian而只需要它与一个向量v的乘积Hv。计算完整的H是 O(n²) 的而Hv可以通过两次反向传播在 O(n) 时间内完成。这就是vjpVector-Jacobian Product的用武之地。from torch.func import vjp def loss_fn(params): # params 是一个包含所有模型参数的 tuple # 这里执行前向传播和 loss 计算 model.set_params(params) # 假设你有这样一个方法 y_pred model(x_batch) return loss_fn(y_pred, y_batch) # 获取 loss 关于 params 的梯度Jacobian _, vjpfunc vjp(loss_fn, params) # v 是一个与 params 结构相同的 tuple代表你要相乘的向量 # 例如v 可以是当前的梯度用于自然梯度下降 hvp vjpfunc(v) # 这就是 Hv # hvp 的结构与 params 完全相同可以直接用于参数更新vjp的工作原理是vjp(loss_fn, params)返回一个函数vjpfunc当你把一个向量v传给它时它会执行一次反向传播但这次不是用1.0作为输出梯度而是用你提供的v。这相当于在计算v^T * J其中J是loss_fn的雅可比矩阵。如果你再对vjpfunc的结果它本身也是一个函数进行一次vjp你就得到了Hv。这是一种非常优雅且高效的“嵌套微分”技巧充分展现了 AD 作为“可微分编程”范式的威力。5. 常见问题与排查技巧实录那些年踩过的坑5.1 “RuntimeError: Trying to backward through the graph a second time” ——retain_graph的迷思这是新手遇到的第一个“拦路虎”。当你第一次调用loss.backward()后PyTorch 默认会释放计算图graph以节省内存。如果你紧接着又调用了一次loss.backward()就会触发这个错误。错误做法loss.backward() # 第一次成功 loss.backward() # 第二次报错正确做法loss.backward(retain_graphTrue) # 第一次保留图 loss.backward() # 第二次可以继续用 # 或者更常见的是只在需要多次 backward 时才 retain loss1.backward(retain_graphTrue) loss2.backward() # loss2 可能依赖 loss1 的一些中间变量注意retain_graphTrue会阻止中间变量被释放导致内存占用翻倍。所以只在绝对必要时才使用它。一个更优雅的替代方案是用torch.autograd.grad来获取梯度因为它不修改.grad属性也不会破坏图。5.2 “

相关新闻